Overview of RL for LLMs: Algorithms and Scaling
19 Apr 2026

This post is a working tour of reinforcement learning for LLMs — the algorithms and the systems that run them at scale. Parts 1–3 cover the math and reference implementation of DPO/PPO/GRPO using TRL as the readable baseline. Parts 4–5 climb up to research-grade variants (REINFORCE++, RLOO, Dr. GRPO, multi-turn agents) using OpenRLHF and discuss reward-modeling pitfalls. Part 6 is the systems side — how verl makes 671B-parameter RL actually run via Ray + vLLM + Megatron, with weight resharding, checkpoint engines, and async rollouts.
- Part 1: The Math
- Part 2: Q&A
- Q: What does β control in DPO, exactly?
- Q: Does the DPO loss have length bias?
- Q: What happens if you remove the reference model from DPO?
- Q: Isn’t “contrastive loss + explicit KL penalty” equivalent to DPO?
- Q: Why does the Critic (value function) only look at the future? Why not past + future?
- Q: In PPO, is the KL penalty in the loss or in the reward?
- Q: In the PPO clipped surrogate, why is it called “surrogate”?
- Q: If π_θ / π_old = 1, can the gradient still flow?
- Q: If a lot of tokens get clipped, we learn nothing on them, right?
- Q: The GAE “effective horizon” — how far back does reward actually propagate?
- Q: What is the Critic’s loss? Why not just regress directly on the final reward?
- Q: Why does Schulman’s k₃ KL estimator have (r−1) − log r instead of just log r?
- Q: What does the Critic’s value look like for a “correct but verbose” response?
- Q: Is GRPO’s scalar trajectory reward actually optimizable?
- Part 3: Implementation
- Part 4: OpenRLHF — Research-Grade Scale
- 4.1 What OpenRLHF adds over TRL
- 4.2 The Algorithm Zoo
- 4.3 Multi-Turn and the Agent Abstraction
- 4.4 Research-Grade Tricks
- 4.4.1 DAPO overlong penalty (length control in the reward)
- 4.4.2 ProRL stop-properly penalty (truncation penalty)
- 4.4.3 GSPO — sequence-level importance sampling
- 4.4.4 vLLM IS correction: TIS, ICEPOP, seq-mask-TIS
- 4.4.5 Dual-clip PPO
- 4.4.6 DAPO dynamic filtering
- 4.4.7 No-std-norm and reference-free training
- 4.5 The Key Takeaways
- Part 5: Reward Modeling — Scale, Overconfidence, Calibration
- Part 6: verl — Production-Scale RL Systems
- 6.1 The HybridFlow Controller — Separating Control from Computation
- 6.2 Hybrid Engine — Per-Tensor Weight Resharding
- 6.3 Checkpoint Engine — A Unified Weight Sync Abstraction
- 6.4 Async Training — Recovering the Long Tail
- 6.5 AgentLoop — Multi-Turn and SWE-Style Tool RL
- 6.6 The Full Picture: Fully-Async Multi-Turn Tool-Use Architecture
- 6.7 Transfer Queue — Distributed Data Pool
- 6.8 Q3 Roadmap as a Map
- 6.9 The Practical Picture
- 6.10 Pseudocode
- 6.11 Key Takeaways
- Part 7: Diagnosing RL Jobs — A Practitioner’s Runbook
Part 1: The Math
All three algorithms share the same high-level goal: make the policy π_θ more likely to produce preferred responses, while keeping it close to a reference model π_ref to prevent language collapse. They differ in how they get the preference signal and how they enforce the closeness constraint.
1.1 DPO: Direct Preference Optimization
DPO works on a static dataset of preference pairs (x, y_w, y_l) — a prompt and a chosen vs. rejected completion. No reward model, no rollouts, no value function.
Implicit reward. The DPO paper’s key derivation: the solution of the KL-constrained RL problem
\[\max_{\pi} \; \mathbb{E}_{y \sim \pi}[r(x,y)] - \beta \, D_{\text{KL}}(\pi \,\|\, \pi_{\text{ref}})\]
has a closed form, and rearranging it expresses the reward through the policy:
\[r(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)} + \beta \log Z(x)\]
Bradley-Terry → loss. Plug that reward into the Bradley-Terry preference model $P(y_w \succ y_l) = \sigma(r_w - r_l)$. The partition function Z(x) cancels because it appears in both r_w and r_l:
\[\boxed{\; L^{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l)}\!\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right] \;}\]
It’s binary cross-entropy over the log-ratio margin. β controls how far the policy may drift from π_ref — implicitly, not as a separate KL term.
1.2 PPO: Proximal Policy Optimization
PPO is online: the policy generates rollouts, a reward model scores them, and a critic (value model) provides per-token baselines.
Per-token reward (the KL penalty is folded in per-token, with the RM score added at the last token T):
\[r_t = -\beta\,(\log \pi_\theta(y_t \mid \cdot) - \log \pi_{\text{ref}}(y_t \mid \cdot)) + \mathbb{1}[t=T] \cdot \text{RM score}\]
GAE advantage (reverse scan from T to 0, typically γ=1, λ=0.95):
\[\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad \hat{A}_t = \delta_t + \gamma\lambda\,\hat{A}_{t+1}\]
Clipped surrogate loss with importance ratio $\rho_t(\theta) = \pi_\theta(y_t \mid \cdot) / \pi_{\text{old}}(y_t \mid \cdot)$:
\[\boxed{\; L^{\text{PPO}}_\pi = -\mathbb{E}_t\!\left[\min\!\big(\rho_t \hat{A}_t,\; \text{clip}(\rho_t, 1{-}\epsilon, 1{+}\epsilon) \hat{A}_t\big)\right] \;}\]
Value loss (also clipped, mirroring the policy clip):
\[L^{\text{PPO}}_V = \tfrac{1}{2} \mathbb{E}_t\!\left[\max\!\big((V_\psi(s_t) - G_t)^2,\; (\text{clip}(V_\psi, V_{\text{old}} \pm \epsilon_v) - G_t)^2\big)\right], \quad G_t = \hat{A}_t + V_{\text{old}}(s_t)\]
Total loss: $L^{\text{PPO}} = L^{\text{PPO}}_\pi + c_v L^{\text{PPO}}_V$.
1.3 GRPO: Group Relative Policy Optimization
GRPO drops the critic. For each prompt, sample G rollouts, score them, and use the group mean as the baseline.
Group-relative advantage (applied uniformly to every token in response i):
\[\hat{A}_i = \frac{R_i - \operatorname{mean}(R_1, \ldots, R_G)}{\operatorname{std}(R_1, \ldots, R_G)}\]
Clipped surrogate — same shape as PPO, but A_i is a scalar per sequence, not per token:
\[L^{\text{GRPO}}_\pi = -\mathbb{E}_{i,t}\!\left[\min\!\big(\rho_{i,t} \hat{A}_i,\; \text{clip}(\rho_{i,t}, 1{-}\epsilon, 1{+}\epsilon) \hat{A}_i\big)\right]\]
KL penalty using Schulman’s $k_3$ estimator (unbiased and non-negative per-sample):
\[\text{KL}_t \approx \exp(\log \pi_{\text{ref}} - \log \pi_\theta) - (\log \pi_{\text{ref}} - \log \pi_\theta) - 1\]
The final GRPO loss adds β · KL_t to each token’s loss. No value model, no GAE.
Part 2: Q&A
Q: What does β control in DPO, exactly?
β plays a dual role: it’s both the preference strength and the implicit KL constraint — unified in a single knob. In the boxed DPO loss, a larger β amplifies the margin (logratio_w − logratio_l), pushing the model to sharpen its preferences; a smaller β pulls the policy gently and lets it stay close to π_ref. The DPO paper derives this from the KL-constrained RL objective — so β literally is the Lagrange multiplier on the KL constraint, not just a heuristic.
Q: Does the DPO loss have length bias?
Yes — famously. Two sources:
- Data-level: annotators (and LLM judges) tend to prefer longer responses.
- Algorithmic: the “log-ratio” is a sum over tokens: $\log \pi_\theta(y\mid x) = \sum_t \log \pi_\theta(y_t \mid \cdot)$. Longer sequences have more surface area to accumulate positive terms, even if the per-token improvement is tiny.
Sharp follow-up: “Doesn’t the probability of a long response shrink due to the multiplication?” Yes — but DPO cares about the log-ratio against π_ref, not the absolute probability. Both models shrink, and the ratio measures the difference per token. If the current model is slightly better than the reference at every token, those small gains accumulate with length:
Short: 10 tokens × (+0.1) = +1.0 total
Long: 100 tokens × (+0.1) = +10.0 total ← DPO sees this as "much better"
Fix: length-normalize the log-ratio by dividing by |y|, or filter the preference data for length balance.
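A minimal sketch of the length-normalization fix (tensor names are ours, not TRL’s): dividing the summed log-ratio by the completion length removes the “more tokens, more surface area” effect from the margin.

import torch
import torch.nn.functional as F

def dpo_loss(chosen_logratio, rejected_logratio, beta, chosen_len=None, rejected_len=None):
    # logratios are the per-token log(pi_theta / pi_ref) summed over the completion, shape (B,)
    if chosen_len is not None:  # length-normalized variant: per-token average instead of a raw sum
        chosen_logratio = chosen_logratio / chosen_len
        rejected_logratio = rejected_logratio / rejected_len
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()

# Same +0.1/token edge over pi_ref, different lengths (100 vs 10 tokens):
long, short = torch.tensor([10.0]), torch.tensor([1.0])
print(dpo_loss(long, short, beta=1.0))                          # vanilla: margin 9 → near-zero loss
print(dpo_loss(long, short, beta=1.0,
               chosen_len=torch.tensor([100.0]),
               rejected_len=torch.tensor([10.0])))              # normalized: margin 0 → log 2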
Q: What happens if you remove the reference model from DPO?
The loss collapses into pure contrastive learning on log-probabilities:
\[L = -\log \sigma\!\left(\beta \log \pi_\theta(y_w \mid x) - \beta \log \pi_\theta(y_l \mid x)\right)\]
This is catastrophically unstable for LLMs:
- Mode collapse: the easiest way to push π(y_w) ≫ π(y_l) is to dump all probability mass onto the winning token sequence. The model stops being a language model and becomes a lookup table.
- Gibberish escape hatch: the model can destroy π(y_l) by making its grammar nonsensical. Without π_ref as a “natural language anchor”, there’s nothing stopping it.
The magic of DPO is that the log-ratio against π_ref implicitly enforces the KL constraint without any separate penalty term.
Q: Isn’t “contrastive loss + explicit KL penalty” equivalent to DPO?
Mathematically yes — you’ve reinvented the RLHF objective, just with a supervised-style gradient instead of an RL gradient. But in practice:
- DPO has one unified term (β controls both margin and KL); explicit KL has two competing terms that are hard to balance.
- DPO’s σ(·) form naturally handles the varying-length issue; a raw KL over 2048 tokens is noisy.
- DPO avoids needing to sample from the model during training (the whole point — it’s offline).
So: theoretically equivalent, practically worse.
Q: Why does the Critic (value function) only look at the future? Why not past + future?
In RL, the past is sunk cost. The agent’s decision at step t should only depend on what it can still influence. If you baked past rewards into V(s_t):
- An agent with high accumulated rewards would become “lazy” (gradient shrinks — future rewards look like a rounding error).
- An agent in debt might become “reckless” (huge negative baseline distorts everything).
By looking only at the future, the agent remains perfectly rational at every state. This is the Markov property: the current state contains all information needed to choose the next action — the path taken to get there is irrelevant.
Human accounting adds past + future (“net worth”). RL only cares about potential (“net present value of remaining rewards”). Two different concepts that share the word “value”.
Q: In PPO, is the KL penalty in the loss or in the reward?
In the reward, per-token:
\[r_t = \underbrace{-\beta(\log \pi_\theta - \log \pi_{\text{ref}})}_{\text{KL, every token}} + \mathbb{1}[t=T] \cdot \text{RM score}\]
This way the Critic naturally learns that “drifting from π_ref” costs value, and GAE distributes the final RM reward and the dense KL costs across all tokens.
Contrast with DPO, where the KL is baked into the loss function via the log-ratio (there is no per-token reward). Contrast also with GRPO, which can put KL in either place — TRL’s default is to add it as a per-token penalty inside the loss, not inside a reward signal.
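To make the two placements concrete, here’s a minimal side-by-side sketch (tensor names and shapes are illustrative, not TRL’s internals):

import torch

beta = 0.05
policy_logps = torch.randn(4, 16) * 0.1 - 2.0   # log pi_theta per completion token, (B, C)
ref_logps    = torch.randn(4, 16) * 0.1 - 2.0   # log pi_ref per completion token
rm_score     = torch.randn(4)                   # one scalar RM score per sequence

# PPO placement: KL is a dense per-token cost inside the reward; RM score only at the last token.
rewards = -beta * (policy_logps - ref_logps)    # (B, C) — GAE then spreads both signals backward
rewards[:, -1] += rm_score

# TRL-GRPO placement: rewards untouched; beta * k3 is added straight to the per-token loss.
diff = ref_logps - policy_logps
per_token_kl = diff.exp() - diff - 1            # k3 estimator, >= 0 per sample
# loss = pg_loss + beta * per_token_kl  (then masked-mean over completion tokens)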
Q: In the PPO clipped surrogate, why is it called “surrogate”?
Because we’re not directly optimizing the true objective $\mathbb{E}_{\pi_\theta}[R]$ — that would require fresh rollouts for every weight update. Instead, we optimize
\[\mathbb{E}_{\pi_{\text{old}}}\!\left[\frac{\pi_\theta}{\pi_{\text{old}}} \hat{A}\right]\]
using importance sampling — the old rollouts are reweighted as if they came from the new policy. That’s a surrogate for the true expected reward, valid only in a small neighborhood around π_old. The clip keeps us inside that neighborhood; once ρ_t walks outside [1-ε, 1+ε], the gradient is flattened so we stop trusting the approximation.
It’s also called a surrogate because it replaces the second-order Trust Region math (TRPO) with a much simpler first-order clipping heuristic — a “poor man’s trust region”.
Q: If π_θ / π_old = 1, can the gradient still flow?
Yes. At the very first PPO inner step, the ratio is exactly 1 everywhere (we haven’t taken a step yet). But the gradient is
\[\nabla_\theta L = -\hat{A}_t \cdot \nabla_\theta \log \pi_\theta(y_t \mid \cdot)\]
The ratio value 1 just means no clipping is active — the magnitude of the update is driven by Â_t. As long as the advantage is non-zero, gradient flows and the model moves. On the next step the ratio diverges from 1 and clipping potentially kicks in.
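A tiny autograd check of this (the log-prob and advantage values are made up): at ratio exactly 1 the clip is inactive and the gradient reduces to −Â_t · ∇logπ.

import torch

logp = torch.tensor([-1.5], requires_grad=True)   # log pi_theta(y_t), one token
old_logp = logp.detach().clone()                  # log pi_old(y_t) — identical, so ratio = 1
adv = torch.tensor([2.0])

ratio = torch.exp(logp - old_logp)
loss = -torch.min(ratio * adv, torch.clamp(ratio, 0.8, 1.2) * adv).sum()
loss.backward()
print(ratio.item(), logp.grad.item())             # 1.0  -2.0  → the update size is driven by Â_t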
Q: If a lot of tokens get clipped, we learn nothing on them, right?
Right. Once ρ_t > 1+ε with positive advantage (or ρ_t < 1-ε with negative advantage), the min selects the clipped branch — which is a constant in θ, so its gradient is zero. Those tokens contribute no signal for the rest of the epoch.
This is a feature: it prevents a single “good” token from hogging all the updates and collapsing the distribution. But if your clip fraction is 80%+, your learning rate is too high or you’re doing too many PPO epochs per rollout — most tokens are hitting the wall and nothing is learning.
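A cheap health metric worth logging — a minimal helper (not a TRL API) that measures exactly this:

import torch

def clip_fraction(new_logps, old_logps, completion_mask, eps=0.2):
    # Fraction of completion tokens whose ratio has left [1-eps, 1+eps] and thus carries no gradient.
    ratio = torch.exp(new_logps - old_logps)
    clipped = (ratio < 1 - eps) | (ratio > 1 + eps)
    return (clipped.float() * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)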
Q: The GAE “effective horizon” — how far back does reward actually propagate?
\[\hat{A}_t = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}, \qquad \text{effective horizon} \approx \frac{1}{1 - \gamma\lambda}\]
For RLHF defaults γ=1, λ=0.95: horizon ≈ 20 tokens. But if only the last token has nonzero reward (and no dense KL penalty), then for a 200-token sequence, tokens 1–180 would see zero. They’d never learn.
Two things rescue the long-horizon case:
- γ = 1 for LLMs, not the usual RL default of 0.99 — because we care whether the final answer is correct, not whether it comes fast.
- The Critic propagates signal. Even if GAE’s direct lookback is ~20 tokens, the Critic’s learned V(s_{t+1}) absorbs downstream rewards and relays them backward. Each training step, the value signal shifts one position earlier. After many steps, even token 1 “knows” about the final reward via the Critic’s learned value surface.
GRPO handles this differently: no Critic, but the scalar reward is applied equally to every token in the sequence — effectively infinite horizon, at the cost of granular credit assignment.
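The horizon formula is worth plugging numbers into — a tiny check:

# effective horizon ≈ 1 / (1 - gamma * lam)
for gamma, lam in [(1.0, 0.95), (1.0, 0.99), (0.99, 0.95)]:
    print(gamma, lam, round(1 / (1 - gamma * lam), 1))   # 20.0, 100.0, 16.8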
Q: What is the Critic’s loss? Why not just regress directly on the final reward?
The Critic minimizes MSE to a GAE target, not the raw final reward:
\[L_V = \tfrac{1}{2}(V_\psi(s_t) - G_t)^2, \qquad G_t = \hat{A}_t + V_{\text{old}}(s_t)\]
Why not $G_t = R$ (the Monte-Carlo return)?
- Variance: different rollouts for the same prompt yield wildly different R. The Critic’s target would swing from +2.0 to +0.2 batch to batch, and MSE gradients explode.
- Bootstrapping: GAE’s target is $V_{\text{old}} + \hat{A}$, which means “my old prediction, plus the surprise”. This is self-correcting — we only move the Critic where it was actually wrong, by exactly that amount of wrongness.
- Speed: bootstrapping lets the Critic learn from partial sequences via TD errors δ_t = r_t + γV(s_{t+1}) − V(s_t), not just complete episodes.
Intuitively: V + A = Expected Future Reward. The advantage is the surprise (“reality minus old guess”), and adding it back to the old guess produces an updated expectation — exactly what a Critic should predict.
Q: Why does Schulman’s k₃ KL estimator have (r−1) − log r instead of just log r?
The naive estimator $k_1 = \log(\pi_\theta / \pi_{\text{ref}})$ is unbiased but can be negative for individual samples (it’s only guaranteed non-negative in expectation, by Gibbs’s inequality, because the positive terms get weighted more in the sum).
A negative KL sample acts like a reward for diverging — exactly backwards. In PPO/GRPO this adds training noise; in the extreme it flips the sign of the penalty.
k₃ derivation. Use the trivial identity $\mathbb{E}_{\pi_\theta}[\pi_{\text{ref}}/\pi_\theta - 1] = 0$ (because the expectation of any probability ratio sampled from the denominator is 1). So we can add this “free zero” to the naive estimator without changing its expectation:
\[k_3 = \underbrace{(r - 1)}_{\text{free zero in expectation}} - \underbrace{\log r}_{\text{naive KL}}, \qquad r = \frac{\pi_{\text{ref}}}{\pi_\theta}\]
By careful choice of this particular zero-mean term, $k_3$ becomes:
- Always ≥ 0 at the sample level: f(r) = (r−1) − log r is convex with a minimum of 0 at r = 1.
- Lower variance near r = 1: a Taylor expansion gives k_3 ≈ ½(r−1)², a quadratic — much smoother than the log.
- Unbiased: still has the same expectation as k_1.
So $k_3 \neq \log r$ pointwise, but $\mathbb{E}[k_3] = \mathbb{E}[k_1] = \text{KL}$, with much better gradient behavior. This is why GRPO (which has no Critic to smooth variance) uses k_3 for its KL term.
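A quick Monte-Carlo check on synthetic Gaussians (stand-ins for π_θ and π_ref, chosen only for illustration) showing both properties at once — same mean, very different per-sample behavior:

import torch

torch.manual_seed(0)
p = torch.distributions.Normal(0.0, 1.0)     # stand-in for pi_theta
q = torch.distributions.Normal(0.3, 1.0)     # stand-in for pi_ref
x = p.sample((200_000,))                     # samples drawn from pi_theta
logr = q.log_prob(x) - p.log_prob(x)         # log(pi_ref / pi_theta)
k1 = -logr
k3 = logr.exp() - logr - 1
print(torch.distributions.kl_divergence(p, q).item(),
      k1.mean().item(), k3.mean().item())                   # all ≈ 0.045
print((k1 < 0).float().mean().item(),
      (k3 < 0).float().mean().item())                       # ≈ 0.44 vs 0.0 negative samples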
Q: What does the Critic’s value look like for a “correct but verbose” response?
Two regimes, depending on whether π_ref is also verbose:
If π_ref is concise (KL penalty is alive): Value rises as the answer becomes clear, peaks at the end-of-answer token, then decays during the verbose tail because each extra token pays a KL tax. The advantage for yapping tokens is negative. The model learns to stop.
If π_ref is also verbose (KL penalty ≈ 0 for yapping): Value plateaus at the peak and stays flat through the verbose tail, because no cost is being subtracted. The advantage of each extra token is V_{t+1} − V_t ≈ 0 → zero gradient → the model has no signal to stop. This is why a rambling SFT reference makes PPO “blind” to verbosity, even though the Critic’s math is correct.
The fix is either a concise reference, a length penalty baked into the RM, or a system prompt like “be concise” that creates a π_θ ≠ π_ref gap to re-enable the KL penalty.
Q: Is GRPO’s scalar trajectory reward actually optimizable?
Yes — but only because of relative ranking. With G rollouts for one prompt:
- Raw scalar: R_i alone carries no gradient signal (everything is compared against what?).
- Group-relative: A_i = (R_i − mean) / std creates a signed signal — positive for winners, negative for losers, averaging to zero. Every token in a winning rollout gets its log-prob pushed up; every token in a loser gets pushed down.
The catch: this is blunt credit assignment. If a rollout has a brilliant paragraph ruined by a typo at the end, every token gets the same negative advantage — including the good ones. GRPO compensates with large group sizes and lots of prompts, so noise averages out.
It works well for tasks where rewards are roughly chain-linked (math, code): if the end is wrong, the middle was probably also wrong, so penalizing the whole trajectory is approximately correct. It works less well for tasks where local quality varies sharply within a trajectory — that’s where PPO + GAE shines.
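A four-rollout numeric example of the group-relative signal (and of the degenerate all-correct group, which produces no signal at all):

import torch

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])                      # G = 4 rollouts, one prompt
print((rewards - rewards.mean()) / (rewards.std() + 1e-4))        # ≈ [+0.87, -0.87, -0.87, +0.87]

saturated = torch.tensor([1.0, 1.0, 1.0, 1.0])                    # every rollout already correct
print((saturated - saturated.mean()) / (saturated.std() + 1e-4))  # all zeros — no gradient signal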
Part 3: Implementation
Here’s how each algorithm looks in TRL, stripped to the default path (no vLLM, no PEFT, no VLM, no tool use). These simplified versions live at trl/trainer/{grpo,dpo}_trainer_math.py and trl/experimental/ppo/ppo_trainer_math.py.
3.1 DPO
The whole trainer is a single compute_loss override — DPO is offline and supervised-style, so no rollouts needed.
def compute_loss(self, model, inputs, ...):
# Batch is structured as [chosen_0,...,chosen_B, rejected_0,...,rejected_B].
# Shape symbols: B = pairs per batch, L = padded sequence length, V = vocab size.
#
# PADDING: each row is [prompt, completion, PAD, ..., PAD] — right-padded.
# No generation step here, so layout is purely for batching. attention_mask
# zeroes the PAD tail; completion_mask additionally zeroes the prompt head.
input_ids = inputs["input_ids"] # (2B, L)
completion_mask = inputs["completion_mask"] # (2B, L) 0=prompt, 1=completion
attention_mask = inputs["attention_mask"] # (2B, L) 0=PAD (right tail)
# 1. Policy forward pass over both chosen and rejected
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
shift_logits = outputs.logits[:, :-1, :] # (2B, L-1, V)
shift_labels = input_ids[:, 1:] # (2B, L-1)
shift_completion_mask = completion_mask[:, 1:] # (2B, L-1)
per_token_logps = selective_log_softmax(shift_logits, shift_labels) # (2B, L-1)
per_token_logps[shift_completion_mask == 0] = 0.0 # mask prompt tokens → 0
logps = per_token_logps.sum(dim=1) # (2B,) sum over completion
chosen_logps, rejected_logps = logps.chunk(2, dim=0) # (B,), (B,)
# 2. Reference forward pass (frozen, no grad)
with torch.no_grad():
ref_outputs = self.ref_model(input_ids=input_ids, attention_mask=attention_mask)
ref_per_token_logps = selective_log_softmax(
ref_outputs.logits[:, :-1, :], shift_labels) # (2B, L-1)
ref_per_token_logps[shift_completion_mask == 0] = 0.0
ref_logps = ref_per_token_logps.sum(dim=1) # (2B,)
ref_chosen_logps, ref_rejected_logps = ref_logps.chunk(2, 0) # (B,), (B,)
# 3. Log-ratios and sigmoid loss
chosen_logratios = chosen_logps - ref_chosen_logps # (B,)
rejected_logratios = rejected_logps - ref_rejected_logps # (B,)
delta_score = chosen_logratios - rejected_logratios # (B,)
loss = -F.logsigmoid(self.beta * delta_score).mean() # scalar
return loss
Key points:
- The chosen and rejected sequences are packed into a single (2B, L) batch and split with .chunk(2, dim=0). This gets a single forward pass for both, keyed by the chunk ordering.
- selective_log_softmax(logits, labels) is equivalent to gather(log_softmax(logits), dim=-1, index=labels.unsqueeze(-1)) but more memory-efficient — it only materializes the log-prob of the actual label token, not the whole vocabulary distribution.
- per_token_logps.sum(dim=1) is the unnormalized-by-length sequence log-prob. This is where the length bias enters: longer sequences accumulate more terms.
3.2 GRPO
GRPO is online but drops the critic. The training step has two phases: generate (no grad) and loss (with grad), with the outputs of phase 1 cached for multiple iterations of phase 2.
Phase 1 — rollout and scoring (shape symbols: B = unique prompts in the batch, G = num_generations, so N = B*G total sequences after the RepeatSampler has copied each prompt G times; P = padded prompt length, C = padded completion length, V = vocab size, F = number of reward functions):
def _generate_and_score_completions(self, inputs):
prompts = [x["prompt"] for x in inputs] # list of length N
num_generations = self.num_generations # G
# 1. Tokenize and LEFT-PAD prompts (each unique prompt already duplicated G times).
# PADDING: prompts are LEFT-padded so every row's last real token sits at column
# P-1. This makes .generate() start producing the first new token at a consistent
# index across the batch, which simplifies slicing later.
# Row layout: [PAD, PAD, PAD, p, p, p, p] ← all rows "end" at col P-1
prompt_ids_list = self.processing_class.apply_chat_template(prompts, ...)
padded_prompt_ids = pad(..., padding_side="left") # (N, P) PAD on the left
prompt_mask = pad(ones_like(...), ...) # (N, P) 1=real, 0=left-PAD
# 2. Generate G completions per prompt
with unwrap_model_for_generation(self.model_wrapped, self.accelerator) as unwrapped:
with torch.no_grad():
prompt_completion_ids = unwrapped.generate(
input_ids=padded_prompt_ids,
attention_mask=prompt_mask,
generation_config=self.generation_config,
) # (N, P+C)
completion_ids = prompt_completion_ids[:, P:] # (N, C)
# 3. Mask everything after the first EOS.
# PADDING: completions are RIGHT-padded — each row generated up to its own EOS
# then continues with PAD tokens to fill to length C. We mask out post-EOS PADs.
# Row layout: [c, c, c, EOS, PAD, PAD] ← short rows pad on the right
is_eos = completion_ids == self.eos_token_id # (N, C)
eos_idx = first_true_index_per_row(is_eos) # (N,)
completion_mask = (arange(C) <= eos_idx[:, None]).int() # (N, C) 1=real, 0=right-PAD
# 4. Compute old_per_token_logps (π_old) — needed for the importance ratio later
# (When training and generation are step-aligned, this equals the current policy
# and we can skip it, using per_token_logps.detach() in compute_loss instead.)
# PADDING: concatenating left-padded prompts with right-padded completions yields
# [LEFT-PAD, prompt, completion, RIGHT-PAD]
# full_mask is 0 at both PAD ends and 1 in the middle — correctly drives attention.
full_ids = torch.cat([padded_prompt_ids, completion_ids], dim=1) # (N, P+C)
full_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (N, P+C)
logits_to_keep = C
with torch.no_grad():
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.model, full_ids, full_mask, logits_to_keep) # (N, C)
# Reference model logps for KL (only if β != 0)
if self.beta != 0.0:
ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.ref_model, full_ids, full_mask, logits_to_keep) # (N, C)
# 5. Score with reward functions (callables or AutoModelForSequenceClassification)
rewards_per_func = self._calculate_rewards(
inputs, prompts, completions, completion_ids_list) # (N, F)
rewards = (rewards_per_func * self.reward_weights).sum(dim=1) # (N,)
# 6. Group-relative advantage — reshape to (B, G), normalize within each row
rewards_grouped = rewards.view(-1, G) # (B, G)
mean = rewards_grouped.mean(dim=1) # (B,)
std = rewards_grouped.std(dim=1) # (B,)
mean = mean.repeat_interleave(G) # (N,)
std = std.repeat_interleave(G) # (N,)
advantages = (rewards - mean) / (std + 1e-4) # (N,)
    return {"prompt_ids": padded_prompt_ids,             # (N, P)
            "completion_ids": completion_ids,            # (N, C)
            "completion_mask": completion_mask,           # (N, C)
            "advantages": advantages,                     # (N,)
            "old_per_token_logps": old_per_token_logps,   # (N, C)
            "ref_per_token_logps": ref_per_token_logps}   # (N, C) if β != 0
- unwrap_model_for_generation(...) is a context manager from trl.models that hands back a plain nn.Module suitable for .generate() — under FSDP/DeepSpeed ZeRO-3 it gathers the sharded parameters onto each rank for the duration of the block, and it also temporarily overrides generation_config with training-time sampling kwargs.
Phase 2 — clipped surrogate loss (B here = per-device micro-batch size — the N generation batch from Phase 1 has been sliced across steps_per_generation update steps):
def compute_loss(self, model, inputs):
input_ids = torch.cat([inputs["prompt_ids"],
inputs["completion_ids"]], dim=1) # (B, P+C)
attention_mask = torch.cat([inputs["prompt_mask"],
inputs["completion_mask"]], dim=1) # (B, P+C)
logits_to_keep = inputs["completion_ids"].size(1) # C
mask = inputs["completion_mask"] # (B, C)
advantages = inputs["advantages"].unsqueeze(1) # (B, 1) broadcasts over C
# Current policy logps π_θ (entropy returned but only used for logging)
per_token_logps, entropies = self._get_per_token_logps_and_entropies(
model, input_ids, attention_mask, logits_to_keep,
compute_entropy=True) # (B, C), (B, C)
# Importance ratio ρ_t = π_θ / π_old
old_logps = inputs.get("old_per_token_logps",
per_token_logps.detach()) # (B, C)
ratio = torch.exp(per_token_logps - old_logps) # (B, C)
# Clipped surrogate — identical structure to PPO
clipped = torch.clamp(ratio, 1 - self.epsilon_low,
1 + self.epsilon_high) # (B, C)
per_token_loss = -torch.min(ratio * advantages,
clipped * advantages) # (B, C)
# k3 KL penalty against reference (only if β != 0)
if self.beta != 0.0:
ref_logps = inputs["ref_per_token_logps"] # (B, C)
diff = ref_logps - per_token_logps # (B, C)
per_token_kl = torch.exp(diff) - diff - 1 # (B, C) ≥ 0
per_token_loss = per_token_loss + self.beta * per_token_kl # (B, C)
# Masked mean: average over completion tokens, then over sequences
    seq_loss = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)  # (B,)
loss = seq_loss.mean() # scalar
return loss
Key points:
- Advantages are scalars per sequence (shape (B,) → unsqueezed to (B, 1)), broadcast across all tokens. Every token in response i gets the same A_i — the “blunt credit assignment” we discussed.
- self._get_per_token_logps re-uses the same utility as DPO’s selective_log_softmax with a shift, temperature divide, and keep-last-K slice.
- k_3 = exp(log_ref - log_θ) - (log_ref - log_θ) - 1 computes Schulman’s non-negative KL estimator per token — no mean or sum over the vocab, just these two logprobs.
- The loss mask is the completion mask: we only apply loss on tokens the policy generated, not on the prompt.
3.3 PPO
PPO has the most moving parts because of the Critic. The entire training loop is a single monolithic train() method (not split into training_step + compute_loss). It cycles through rollout → reward processing → GAE → multi-epoch update.
Shape symbols in this block: B = per-device rollout batch size in Phase 1/2, P = prompt length (context_length), C = response length (gen_length), V = vocab size. Phase 3 uses B again for the per-device micro-batch inside the inner update loop (same letter, smaller slice of the rollout).
def train(self):
for update in range(num_total_batches):
data = next(iter_dataloader)
# =================== Phase 1: ROLLOUT (no grad) ===================
with torch.no_grad():
# PADDING: prompts come in LEFT-padded from DataCollatorWithPadding
# (tokenizer.padding_side = "left"). Layout:
# [PAD, ..., PAD, p, p, p] ← all rows end at column P-1
# batch_generation handles the mask/position_ids internally so generation
# produces contiguous new tokens after column P-1 on every row.
queries = data["input_ids"] # (B, P) left-PAD
context_length = queries.shape[1] # P
# Generate responses via unwrapped policy. `self.model` here is a
# PolicyAndValueWrapper, so the unwrapped object has `.policy` and
# `.value_model`; we pass `.policy` to batch_generation.
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped:
query_responses, logits = batch_generation(
unwrapped.policy, queries,
args.local_rollout_forward_batch_size,
pad_token_id, generation_config,
)
# PADDING: query_responses = [left-PAD, prompt, completion, right-PAD].
# Completions differ in length (EOS arrives at different columns), so the
# tail after each row's EOS is right-padded to a common length C.
# query_responses: (B, P+C) logits: (B, C, V) (from generation scores)
responses = query_responses[:, context_length:] # (B, C) right-PAD tail
# Policy logps (from generation logits, divided by temperature)
logprobs = selective_log_softmax(
logits / temperature, responses) # (B, C)
# Reference model logps (for KL penalty) — separate forward pass
ref_output = forward(ref_policy, query_responses, pad_token_id)
ref_logits = ref_output.logits[:, P-1 : -1] / temperature # (B, C, V)
ref_logprobs = selective_log_softmax(ref_logits, responses) # (B, C)
# Value predictions V(s_t) from the Critic
full_value, _, _ = get_reward(unwrapped_value_model,
query_responses, ...) # (B, P+C, 1)
values = full_value[:, P-1 : -1].squeeze(-1) # (B, C)
# Reward model score at the last token (truncated response)
postprocessed = truncate_response(
stop_token_id, pad_token_id, responses) # (B, C)
_, scores, _ = get_reward(
reward_model, postprocessed_query_response, ...) # scores: (B,)
# Missing-EOS penalty
if args.missing_eos_penalty is not None:
contain_eos = (postprocessed == eos_token_id).any(dim=-1) # (B,)
scores[~contain_eos] -= args.missing_eos_penalty
# ============== Phase 2: PER-TOKEN REWARD + GAE ==============
# KL reward per token (k1 or k3 estimator)
logr = ref_logprobs - logprobs # (B, C)
kl = -logr if args.kl_estimator == "k1" \
else (logr.exp() - 1) - logr # (B, C)
non_score_reward = -args.kl_coef * kl # (B, C)
rewards = non_score_reward.clone() # (B, C)
# Add scalar RM score at the last non-padded token of each row
rewards[batch_idx, seq_len + 1] += scores # scatter (B,) → (B, C)
# GAE reverse scan: δ_t = r_t + γV_{t+1} - V_t
# A_t = δ_t + γλ A_{t+1}
lastgaelam = 0
advantages_rev = []
for t in reversed(range(C)):
nextvalues = values[:, t + 1] if t < C - 1 else 0.0 # (B,)
delta = rewards[:, t] + args.gamma * nextvalues \
- values[:, t] # (B,)
lastgaelam = delta + args.gamma * args.lam * lastgaelam # (B,)
advantages_rev.append(lastgaelam)
advantages = torch.stack(advantages_rev[::-1], dim=1) # (B, C)
returns = advantages + values # (B, C) — Critic target
advantages = masked_whiten(advantages, ~padding_mask) # (B, C) zero-mean, unit-var
# ============== Phase 3: PPO UPDATE (multi-epoch) ==============
# All mb_* tensors are FROZEN slices of Phase 1/2 outputs — they don't
# update during the inner loop. Only new_logprobs and vpred (from the
# current model) get recomputed each step.
# mb_logprobs = log π_old(y_t | s_t) (from rollout generation scores)
# mb_advantage = Â_t (whitened, from GAE in Phase 2)
# mb_return = G_t = Â_t + V_old (Critic target — fixed)
# mb_values = V_old(s_t) (Critic snapshot at rollout)
for ppo_epoch in range(args.num_ppo_epochs):
for minibatch in random_shuffle(data):
# mb_query_responses: (B, P+C), where B = per_device_train_batch_size.
# Forward through policy + value in one call (PolicyAndValueWrapper).
output, vpred_raw = forward(model, mb_query_responses, pad_token_id)
# output.logits: (B, P+C, V) vpred_raw: (B, P+C, 1)
new_logits = output.logits[:, P-1 : -1] / temperature # (B, C, V)
new_logprobs = selective_log_softmax(new_logits, mb_responses) # (B, C)
vpred = vpred_raw[:, P-1 : -1].squeeze(-1) # (B, C)
# === Clipped value loss ===
# L_V = 0.5 · max((V - G)², (clip(V, V_old ± ε_v) - G)²)
vpred_clipped = torch.clamp(
vpred, mb_values - args.cliprange_value,
mb_values + args.cliprange_value) # (B, C)
vf_losses = (vpred - mb_return) ** 2 # (B, C)
vf_losses2 = (vpred_clipped - mb_return) ** 2 # (B, C)
vf_loss = 0.5 * masked_mean(
torch.max(vf_losses, vf_losses2),
~padding_mask_p1) # scalar
# === Clipped policy surrogate ===
# L_π = max(-A·ρ, -A·clip(ρ, 1±ε))
# equivalent to -min(A·ρ, A·clip(ρ)) — PPO's standard form
ratio = torch.exp(new_logprobs - mb_logprobs) # (B, C)
pg_losses = -mb_advantage * ratio # (B, C)
pg_losses2 = -mb_advantage * torch.clamp(
ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) # (B, C)
pg_loss = masked_mean(
torch.max(pg_losses, pg_losses2), ~padding_mask) # scalar
# Total loss — NO entropy bonus in this implementation
loss = pg_loss + args.vf_coef * vf_loss # scalar
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
Key points:
- PolicyAndValueWrapper: a tiny nn.Module that holds both the policy and value model, so a single accelerator.prepare(model) handles both. Its forward returns (policy_output, value_logits) in one call — the two models share the backbone forward pass.
- KL goes in the reward, not the loss (see Q&A #6). non_score_reward = -kl_coef * kl is added to every token; the RM scalar is added only at the final token. Then GAE processes the combined reward vector.
- Reward whitening (masked_whiten(advantages, mask)) normalizes advantages to zero-mean unit-variance across the minibatch. This is crucial — raw advantages can be massive (the full RM score) or zero (intermediate tokens), and whitening stabilizes the policy gradient.
- Two clip ranges: cliprange for the policy ratio (standard PPO), cliprange_value for the Critic. Both defend against large single-step changes.
- No entropy bonus in this implementation. The KL penalty to π_ref serves a similar role — it keeps the policy from collapsing into a delta distribution.
- Multi-epoch inner loop: num_ppo_epochs × num_mini_batches × gradient_accumulation_steps updates per rollout. Each inner step uses the same rollout data with freshly-computed ratio against the initial rollout-time logprobs π_old. This is what the “proximal” in PPO buys you — reusing expensive rollouts safely.
3.4 Summary
| Aspect | DPO | GRPO | PPO |
|---|---|---|---|
| Rollout needed? | No (offline) | Yes | Yes |
| Reward model? | No (implicit via preferences) | Yes (or rule-based function) | Yes |
| Critic/value model? | No | No | Yes |
| Baseline for advantage | n/a | Group mean | V(s_t) from Critic |
| KL location | Implicit in loss (log-ratio) | Explicit in loss (k₃ term) | In per-token reward |
| Models in memory | 2 (π_θ, π_ref) | 3 (π_θ, π_ref, RM) | 4 (π_θ, π_ref, RM, V) |
| Credit assignment | Sequence-level | Sequence-level (blunt) | Token-level (via GAE) |
| Main failure mode | Length bias, data noise | Credit assignment noise | Critic instability |
| Paper | 2305.18290 | 2402.03300 | 1707.06347 |
If you know your preferences come from a well-curated static dataset, DPO is the cheapest correct answer. If you have a verifiable reward (math, code, unit tests) and want to exploit it with lots of rollouts, GRPO is simplest. If you want the strongest token-level credit assignment and have the VRAM budget for 4 models, PPO is still the workhorse — but it’s the hardest to tune.
Part 4: OpenRLHF — Research-Grade Scale
TRL is the right reference implementation for the canonical algorithms. OpenRLHF is what you reach for when you need more: a production-scale framework (Ray + vLLM + DeepSpeed) with an algorithm zoo (PPO, REINFORCE++, REINFORCE++-baseline, GRPO, Dr. GRPO, RLOO), multi-turn agent execution, and a stack of research-grade tricks that appeared in 2024–2026 papers (DAPO, ScaleRL, ProRL, GSPO, etc.). This section is about the algorithmic and capability advances — not the distributed-systems plumbing.
4.1 What OpenRLHF adds over TRL
At the framework level:
| Capability | TRL | OpenRLHF |
|---|---|---|
| Algorithms | DPO, PPO, GRPO | + REINFORCE++, REINFORCE++-baseline, RLOO, Dr. GRPO |
| Critic | Always for PPO | Optional for all variants (most are critic-free) |
| Multi-turn rollouts | No first-class support | Native reset/step agent API, tokens preserved across turns |
| Agent execution | n/a | Token-in-token-out unified pipeline (Single/Multi-Turn executors) |
| Length penalties | — | DAPO overlong + ProRL stop-properly reward shaping |
| Dynamic filtering | — | Sampling-level filter by reward score (DAPO) |
| Policy loss variants | PPO-clip | + GSPO (sequence-level IS) |
| Off-policy IS correction | Basic vLLM IS | TIS / ICEPOP / seq-mask-TIS (three strategies) |
| Dual-clip PPO | — | Bounds the negative-advantage branch |
| Reference-free training | n/a (β=0 works) | --algo.kl.init_coef 0 as a first-class path |
| VLM RL | Image-in prompt (experimental) | Full VLM RLHF incl. multi-turn with image feedback |
| Optimizer | AdamW | + Muon (2D weights) with aux-AdamW for 1D/embeddings |
The common thread: everything is a knob you can toggle independently. The RL algorithm, the execution mode (single/multi-turn), the reward shaping, and the IS correction are all orthogonal.
4.2 The Algorithm Zoo
Every OpenRLHF online algorithm follows the same skeleton: sample G = n_samples_per_prompt rollouts per prompt, shape per-group rewards, compute advantages, optionally whiten them, apply PPO-clip with multi-epoch updates. Each algorithm is just a particular assembly of shared building blocks. The cleanest way to understand them is to define those blocks once with their math, then point at which algorithm uses which.
4.2.1 The Building Blocks
Notation. For a given prompt, the policy samples G rollouts producing scalar rewards R_1, ..., R_G. Each rollout is a token sequence of length T. Let $\bar R = \frac{1}{G}\sum_i R_i$ be the group mean and $\sigma_R$ be the group standard deviation.
Block A — Per-group reward shaping. Reshape the scalar reward R_i of each rollout into $\tilde R_i$ before it’s spread across tokens. Four options:
- A1: identity — $\tilde R_i = R_i$ (no shaping).
- A2: group-mean subtract — $\tilde R_i = R_i - \bar R$.
- A3: leave-one-out (LOO) — $\tilde R_i = R_i - \frac{1}{G-1}\sum_{j \ne i} R_j$.
- A4: group normalization — $\tilde R_i = (R_i - \bar R)/\sigma_R$.
Block B — KL penalty. OpenRLHF supports two placements for the KL term, switched by --algo.kl.use_loss (default False). The flag affects every algorithm, including GRPO and PPO.
- B-reward (default, use_loss=False) — KL goes into the per-token reward (PPO style, Part 2 Q&A). The shaped scalar $\tilde R$ is added at the last action token; KL is added at every action token:
\[r_t = -\beta\,(\log \pi_\theta(y_t \mid \cdot) - \log \pi_{\text{ref}}(y_t \mid \cdot)) + \mathbb{1}[t=T] \cdot \tilde R\]
- B-loss (use_loss=True) — KL is removed from the reward and added as a separate term to the actor loss (TRL’s GRPO style):
\[L = L_\pi + \beta \cdot \mathbb{E}_t[\text{KL}_t]\]
The Schulman estimator (k1/k2/k3, Part 2 Q&A) should match the mode — OpenRLHF’s runtime even prints a recommendation: k1 (the naive log r) when KL is in the reward (per-sample negativity is fine — the trajectory average is unbiased); k2 or k3 (the always-non-negative variants) when KL is in the loss (a per-sample negative value would act as a reward for diverging).
So when this section says “KL-in-reward” elsewhere, that’s the default path. The other path (KL-in-loss with k3) is one flag away and matches what TRL’s GRPO does.
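For reference, the three estimators side by side, per token (a small sketch; the function name is ours):

import torch

def kl_estimators(policy_logps, ref_logps):
    # Estimators of KL(pi_theta || pi_ref) evaluated at tokens sampled from pi_theta,
    # following Schulman's "Approximating KL Divergence" note. logr = log(pi_ref / pi_theta).
    logr = ref_logps - policy_logps
    k1 = -logr                      # unbiased, can be negative per token → fine inside the reward
    k2 = 0.5 * logr.pow(2)          # always >= 0, low variance, slightly biased
    k3 = logr.exp() - logr - 1      # unbiased and always >= 0 → the safe choice inside the loss
    return k1, k2, k3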
Block C — Advantage estimation. Two choices:
- C1: GAE (with critic). Standard PPO recipe.
- C2: Cumulative return (no critic, γ=1 in RLHF). Every action token gets the same scalar — the trajectory’s discounted return from t onward, which with γ=1 reduces to the constant trajectory-level shaped reward.
Block D — Batch-level advantage whitening. Across all rollouts (all prompts) in a batch, normalize advantages to zero-mean unit-variance:
\[\hat A_t \leftarrow \frac{\hat A_t - \mu_{\text{batch}}}{\sigma_{\text{batch}} + \varepsilon}\]
(Toggleable via --algo.advantage.no_std_norm, which keeps the mean-subtract but skips the /σ.)
Block E — PPO-clip policy loss. With ratio $\rho_t = \pi_\theta(y_t) / \pi_{\text{old}}(y_t)$:
\[L_\pi = -\mathbb{E}_t\!\left[\min\!\big(\rho_t \hat A_t,\; \text{clip}(\rho_t, 1-\epsilon, 1+\epsilon) \hat A_t\big)\right]\]
Block F — Critic value loss (only used together with C1):
\[L_V = \tfrac{1}{2}\mathbb{E}_t\!\left[(V_\psi(s_t) - G_t)^2\right], \qquad G_t = \hat A_t + V_{\text{old}}(s_t)\]
Block G — Multi-epoch updates. Reuse one rollout for μ = num_ppo_epochs passes of optimizer steps, freezing $\pi_{\text{old}}$, $\hat A$, $G_t$, and $V_{\text{old}}$ — the “proximal” guarantee from Part 2.
Together: Blocks B, E, G are shared by every online algorithm. The variants differ on four knobs: which A, which C, whether D is on, whether F is on.
4.2.2 Mapping Blocks to Algorithms
| Algorithm | --advantage.estimator | A: per-group shaping | C: advantage | D: batch whiten | F: critic loss |
|---|---|---|---|---|---|
| PPO | gae | — (no shaping; critic provides baseline) | C1 (GAE) | ✓ | ✓ |
| REINFORCE++ | reinforce | A1: identity ($R_i$) | C2 (cumret) | ✓ | ✗ |
| RLOO | rloo | A3: LOO | C2 (cumret) | ✗ | ✗ |
| REINFORCE++-baseline | reinforce_baseline | A2: $R_i - \bar R$ | C2 (cumret) | ✓ | ✗ |
| GRPO | group_norm | A4: $(R_i - \bar R)/\sigma_R$ | C2 (cumret) | ✗ | ✗ |
| Dr. GRPO | dr_grpo | A2: $R_i - \bar R$ | C2 (cumret) | ✗ | ✗ |
All six share Blocks B, E, G (KL-in-reward, PPO-clip, multi-epoch). PPO is the only one with a critic (and therefore the only one using GAE + value loss).
The corresponding code lives in experience_maker.py:
# Block A — per-group shaping
if estimator == "rloo":
baseline = (rewards.sum(-1, keepdim=True) - rewards) / (G - 1)
rewards = rewards - baseline # A3
elif estimator in ["reinforce_baseline", "dr_grpo"]:
rewards = rewards - rewards.mean(-1, keepdim=True) # A2
elif estimator == "group_norm":
rewards = (rewards - rewards.mean(-1, keepdim=True)) / (rewards.std(-1, keepdim=True) + 1e-9) # A4
# (plain "reinforce" → A1 identity; "gae" → no per-group shaping)
# Block C — advantage
if estimator == "gae":
advantages, returns = get_advantages_and_returns(values, rewards, ...) # C1
else:
returns = get_cumulative_returns(rewards, action_mask, gamma=1.0) # C2
advantages = returns.clone()
# Block D — batch whitening
if estimator in ["gae", "reinforce", "reinforce_baseline"]:
advantages = (advantages - global_mean) / global_std
Three lines of code separate Dr. GRPO from REINFORCE++-baseline (Block D toggle), and one branch separates Dr. GRPO from GRPO (Block A2 vs A4).
4.2.3 Per-Algorithm Notes
- PPO (gae) — The classical setup. The critic $V(s)$ provides a learned, state-dependent baseline that lets GAE distribute the final reward back across tokens with token-level credit assignment. Best when you can afford the 4-model VRAM cost and the critic-tuning overhead.
- REINFORCE++ (reinforce) (arxiv 2501.03262) — REINFORCE plus all the PPO stability tricks except the critic: KL-in-reward (B), PPO-clip (E), multi-epoch (G), batch whitening (D). The whitening is the trick that rescues REINFORCE without a critic — without it, the gradient variance blows up. Used by ScaleRL and (per the README) something close to it by Magistral. Logic-RL and PRIME report it’s “more stable than GRPO, faster than PPO”.
- RLOO (rloo) (arxiv 2402.14740, “Back to Basics”) — The leave-one-out baseline is unbiased by construction: rollout i’s baseline doesn’t contain R_i, so subtracting it doesn’t shift the expected gradient. This is the same jackknife trick from classical statistics. Especially valuable when G is small (~4) — when G is large, $\frac{\sum_{j\ne i} R_j}{G-1} \approx \bar R$ and the LOO bias-correction matters less. No batch whitening because the LOO baseline is already variance-reducing, and adding a /σ could distort the unbiased property.
- REINFORCE++-baseline (reinforce_baseline) — Same as REINFORCE++ plus Block A2 (group-mean subtract). The group mean acts as a per-prompt baseline (“how hard was this prompt for the current policy?”), then the global whitening normalizes scale across prompts of different difficulty. The README’s recommended choice for RLVR (reasoning tasks with verifiable rewards), validated at scale by ProRL V2 and ScaleRL.
- GRPO (group_norm) — The DeepSeekMath canonical: $(R_i - \bar R)/\sigma_R$. The /σ is appealing because it makes advantages scale-free per prompt — but it has a pathology (next entry).
- Dr. GRPO (dr_grpo) (arxiv 2503.20783, “Understanding R1-Zero-Like Training”) — “Dr.” = “Done Right”. Identifies two biases in vanilla GRPO:
  - Difficulty bias from /σ: when all G rollouts succeed (easy prompt) or all fail (hard prompt), $\sigma_R \approx 0$ and the advantages explode in magnitude. These prompts dominate gradients while medium-difficulty prompts (where the actual learning signal lives) are drowned out. Dr. GRPO drops the /σ, keeping just the mean subtract.
  - Length bias from loss aggregation (a separate fix at the loss-aggregation step, not in advantage shaping): GRPO averages per-token losses within each sequence first, which down-weights longer sequences. Dr. GRPO normalizes by a constant (max generation length).

  At the advantage-shaping step, Dr. GRPO and REINFORCE++-baseline are identical (both use A2). The only difference is Block D: REINFORCE++-baseline applies batch whitening; Dr. GRPO does not. Practical implication — pick Dr. GRPO when reward magnitudes are uniform across prompts (e.g. binary 0/1 verifiable rewards) so you don’t need extra rescaling; pick REINFORCE++-baseline when prompts have wildly different reward ranges.
4.3 Multi-Turn and the Agent Abstraction
Most RL-for-LLMs training is single-turn: one prompt → one response → one reward. Reasoning, tool-use, coding-with-feedback, and interactive games are fundamentally multi-turn: the policy acts, the environment returns an observation, the policy acts again, repeated until done.
The three roles
There are three actors in a multi-turn rollout. Their names in OpenRLHF are slightly confusing because the framework reuses “agent” for the environment, so let’s pin down the vocabulary first:
| Role | Classical RL term | OpenRLHF class | What it does |
|---|---|---|---|
| LLM | Policy / agent | llm_engine (vLLM) | Generates action tokens given current context |
| Environment | Environment | AgentInstance | Defines the task: gives initial observation, scores actions, returns feedback |
| Orchestrator | (the loop driver) | MultiTurnAgentExecutor | Bridges the two — runs the rollout loop and assembles a single token trajectory |
So the “Agent” in OpenRLHF is the environment, wrapped as a Python class. The actual decision-making policy is the LLM. The Executor is the runtime that drives the conversation between them.
How they interact
The Executor is the only caller. It calls generate on the LLM and gets action tokens back; calls step on the Environment and gets (reward, feedback, done) back; appends both to the running observation; loops until done. Time flows downward in this sequence diagram:
Executor LLM Environment
│ │ │
│ reset({prompt, label}) │
├─────────────────────────────────────────────────────▶│
│ │
│ ◀────────────────────────────────────────────────────┤
│ initial observation (text) │
│ │ │
│ ═══ loop until done ═══════════════════════════════════
│ │ │
│ generate(obs_tokens, sampling_params) │
├───────────────────────────▶│ │
│ │ │
│ ◀──────────────────────────┤ │
│ action_tokens, action_text, logprobs │
│ │ │
│ step({action_text, obs_text, label, ...}) │
├─────────────────────────────────────────────────────▶│
│ │
│ ◀────────────────────────────────────────────────────┤
│ {rewards, environment_feedback, done, ...} │
│ │ │
│ obs_tokens += action_tokens + feedback_tokens │
│ total_reward += step["rewards"] │
│ action_ranges.append((start, end)) │
│ if done: break │
▼ ▼ ▼
Then the assembled trajectory — obs_tokens (the full token sequence) plus action_ranges (which slices were LLM-produced) — feeds into the RL algorithm (PPO / REINFORCE++ / GRPO / …) which trains the LLM on its actions only.
The Executor strictly alternates: ask the LLM for an action, hand the action to the Environment, get feedback, append, repeat. It never queries both at once.
Why the agent has two functions
AgentInstance exposes reset and step because that’s the canonical OpenAI-Gym interface for environments — battle-tested in classical RL for over a decade.
class AgentInstance(AgentInstanceBase):
async def reset(self, states: dict):
# Called ONCE at episode start.
# Input: {observation, label, ...} from the dataset row
# Output: {observation: <text>} initial state shown to the LLM
...
async def step(self, state_dict: dict) -> dict:
# Called ONCE PER MODEL TURN, after the LLM emits an action.
# Input: {observation_text, action_text, label, sampling_params}
# Output: {
# "rewards": scalar reward for this step,
# "environment_feedback": text to splice in before next LLM turn,
# "done": bool, terminate the episode?
# "scores": optional 0-1 score for dynamic filtering,
# "extra_logs": anything to log to wandb,
# }
...
Who calls them? The MultiTurnAgentExecutor.execute() method calls both, in fixed order:
1. agent.reset(states) — once, at the start of the trajectory.
2. llm_engine.generate(obs_tokens, ...) — to produce the next action.
3. agent.step({action, obs, ...}) — once, to score that action and produce environment feedback.
4. Loop back to step 2 with appended tokens, until done=True or the budget is exhausted.
You as a user write the AgentInstance (your task definition) and reset/step are your “API” to the framework. You never call them yourself — the Executor does.
Pseudocode: the multi-turn rollout
This is what MultiTurnAgentExecutor.execute() actually does (simplified, from agent.py):
async def execute(self, prompt, label, sampling_params, max_length, tokenizer, llm_engine):
# 1. Spin up a fresh environment for this rollout.
env = AgentInstance() # ← user-provided class
# 2. Get the initial observation from the environment.
initial_obs = await env.reset({"observation": prompt, "label": label})
obs_text = initial_obs["observation"]
obs_tokens = tokenizer(obs_text, add_special_tokens=False)["input_ids"]
total_reward = 0.0
action_ranges = [] # [(start, end), ...] in token space — the LLM's spans
rollout_logprobs = [] # for IS correction / GSPO
# 3. Multi-turn rollout loop.
while True:
# 3a. Budget check — leave room for at least one more generation.
sampling_params.max_tokens = max_length - len(obs_tokens)
if sampling_params.max_tokens <= 0:
break
# 3b. Ask the LLM for an action (token-in, token-out — no text round-trip).
request_output = await llm_engine.generate(obs_tokens, sampling_params)
action_tokens = request_output.outputs[0].token_ids
action_text = request_output.outputs[0].text
rollout_logprobs.extend(request_output.outputs[0].logprobs) # for training
# 3c. Mark the action's span (used as the loss mask later).
action_start = len(obs_tokens)
action_end = action_start + len(action_tokens)
action_ranges.append((action_start, action_end))
# 3d. Send the action to the environment, get feedback.
step_result = await env.step({
"observation_text": obs_text,
"action_text": action_text,
"label": label,
"sampling_params": sampling_params,
})
total_reward += step_result["rewards"].item()
feedback_text = step_result["environment_feedback"]
done = step_result["done"]
# 3e. Tokenize feedback and splice it onto the running sequence.
feedback_tokens = tokenizer(feedback_text, add_special_tokens=False)["input_ids"]
obs_text = obs_text + action_text + feedback_text
obs_tokens = obs_tokens + action_tokens + feedback_tokens
rollout_logprobs.extend([0.0] * len(feedback_tokens)) # env tokens not from LLM
# 3f. Termination conditions.
if done:
break
# 4. Hand back one trajectory: a flat token sequence with marked action spans.
return {
"observation_tokens": obs_tokens, # [prompt, action_1, fb_1, action_2, fb_2, ...]
"action_ranges": action_ranges, # which slices are LLM-produced
"reward": total_reward, # scalar trajectory reward
"rollout_log_probs": rollout_logprobs, # per-token, 0.0 for env tokens
}
Pseudocode: the outer RL loop
The Executor produces token trajectories. The trainer wraps it in the standard RL training loop:
for update in range(num_updates):
# 1. ROLLOUT — for each prompt, run G trajectories in parallel.
trajectories = []
for prompt, label in batch:
for _ in range(n_samples_per_prompt): # G rollouts per prompt
traj = await agent_executor.execute(
prompt, label, sampling_params, max_length,
tokenizer, llm_engine)
trajectories.append(traj)
# 2. SHAPE PER-GROUP REWARDS (Part 4.2 — pick one):
# rloo : R - leave_one_out_mean
# reinforce_baseline / dr_grpo : R - group_mean
# group_norm (GRPO): (R - group_mean) / group_std
rewards = [t["reward"] for t in trajectories]
advantages = shape_per_group(rewards, estimator) # see §4.2
# 3. OPTIONAL REWARD SHAPING (length penalties etc.)
apply_length_penalties(trajectories, args) # DAPO / ProRL — §4.4.1, 4.4.2
# 4. DYNAMIC FILTERING — drop saturated groups (§4.4.6).
if args.algo.dynamic_filtering_enable:
trajectories, advantages = filter_by_score(trajectories, advantages)
# 5. BUILD TRAINING TENSORS — only LLM-produced tokens contribute to loss.
for traj in trajectories:
traj["action_mask"] = build_mask_from(traj["action_ranges"]) # 1 inside ranges, 0 outside
# 6. POLICY UPDATE — multi-epoch PPO-clip (with optional KL penalty).
for ppo_epoch in range(num_ppo_epochs):
for minibatch in shuffle(trajectories):
new_logprobs = forward_policy(model, minibatch)
loss = ppo_clip_loss(new_logprobs, minibatch.old_logprobs,
minibatch.advantages, minibatch.action_mask)
if args.algo.kl.init_coef > 0:
ref_logprobs = forward_policy(ref_model, minibatch)
loss = loss + beta * k3_kl(ref_logprobs, new_logprobs, minibatch.action_mask)
loss.backward(); optimizer.step()
Three things to notice in the outer loop:
- The RL algorithm choice (PPO/REINFORCE++/GRPO/Dr. GRPO/RLOO) only changes step 2 — the per-group reward shaping. Steps 1, 3, 4, 5, 6 are identical across all variants. This is the “decoupled” promise from the README’s architecture diagram.
- Single-turn is multi-turn with done=True after step 1. The SingleTurnAgentExecutor is structurally the same loop, just run once. That’s why the README says single-turn and multi-turn are “orthogonal to RL algorithms” — they share both the executor interface and the trainer downstream.
- The action mask is what makes multi-turn work mathematically. Environment-feedback tokens are in the sequence (the LLM saw them as context), but they are not model output, so they’re excluded from both the policy gradient and the KL penalty. The action_ranges returned by execute() is exactly this mask in compressed form (a minimal sketch of building that mask follows below).
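The outer-loop pseudocode above calls build_mask_from(action_ranges); here is a minimal version of what such a helper has to do (the name comes from the pseudocode, not necessarily OpenRLHF’s actual utility):

import torch

def build_mask_from(action_ranges, seq_len):
    # 1 for tokens the LLM generated (inside any (start, end) span), 0 for prompt and
    # environment-feedback tokens — the mask used by both the policy loss and the KL term.
    mask = torch.zeros(seq_len, dtype=torch.long)
    for start, end in action_ranges:
        mask[start:end] = 1
    return mask

# 12-token trajectory: prompt 0–3, action 4–7, feedback 8–9, action 10–11
print(build_mask_from([(4, 8), (10, 12)], seq_len=12))
# tensor([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1])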
Why token-in-token-out matters
The defining design principle: everything the LLM sees and produces is a token id, never re-tokenized text. Consequences:
- Zero text-mismatch during training. The token sequence used for the loss is literally the sequence the LLM saw at rollout time. No risk of the tokenizer encoding a string differently the second time around (whitespace, special tokens, byte-pair quirks).
- Delta tokenization. For prefix-sharing turns (OpenAI chat format etc.), only the delta is freshly tokenized; earlier turns’ token ids are reused verbatim across calls.
- VLM consistency. For multi-turn VLM (screenshots in environment feedback), image-pad tokens in the prompt align exactly with pixel_values in the forward pass. Text-level re-tokenization would silently drop image placeholders and break the alignment.
Implications for rewards and training data
A multi-turn rollout accumulates rewards across steps (total_reward += step_result["rewards"]). The way this turns into training signal depends on the algorithm:
- Credit assignment is trajectory-level. Only the final total_reward is used as the trajectory’s scalar reward. It’s spread across all action tokens (the union of action_ranges), weighted by whatever advantage estimator you chose. Environment-feedback tokens are masked out of the loss.
- Sparse intermediate rewards are allowed. You can set step_result["rewards"] per step and the sum is what matters. Most implementations give 0 mid-rollout and the full reward at the final step.
- done is decided by the environment. The budget (max_length) also forces stops. Truncated trajectories can be penalized via the ProRL stop-properly trick (§4.4.2).
- Variable-length trajectories per prompt. Group baselines (RLOO, reinforce_baseline, Dr. GRPO) compare rollouts regardless of length — the comparison is over scalar rewards, not over sequences.
- done ≠ EOS. The LLM’s per-turn generation stops at EOS each round; “done” means the episode terminates. The same prompt can yield 1-turn, 3-turn, 7-turn trajectories.
Dataset implication: a “prompt” is just (prompt, label), no completions. The agent code generates everything downstream. Rewards live in code (agent.step), not in a dataset column — and that’s the only way to score trajectories that branch into different states based on what the policy chose to do.
4.4 Research-Grade Tricks
Beyond the algorithms themselves, OpenRLHF ships flags for a stack of tricks from recent papers.
4.4.1 DAPO overlong penalty (length control in the reward)
DAPO adds a soft length penalty to reward-shape away very long responses:
expected_len = max_new_tokens - overlong_buffer_len
if valid_response_length > expected_len:
    exceed_len = min(valid_response_length - expected_len, overlong_buffer_len)
    penalty = -exceed_len / overlong_buffer_len * overlong_penalty_factor
    rewards[i] += penalty
With max_new_tokens=2048, overlong_buffer_len=512: responses up to 1536 tokens get no penalty; between 1536 and 2048 the penalty ramps linearly from 0 to -penalty_factor. This is smoother than a hard cutoff and it nudges the model to finish its thought inside the budget.
Flags: --reward.overlong_buffer_len, --reward.overlong_penalty_factor.
4.4.2 ProRL stop-properly penalty (truncation penalty)
A complementary trick from ProRL, using vLLM’s finish_reason == "length" signal to detect truncated samples:
if coef >= 0:   # multiplicative: scale truncated rewards by coef in [0, 1]
    rewards[j] = rewards[j] * coef
else:           # absolute: override to coef (e.g. -0.5)
    rewards[j] = coef
The multiplicative form gives you “zero out if truncated” (coef=0) or “discount heavily” (e.g. coef=0.1). The absolute form gives truncated samples a fixed negative reward (coef=-0.5). Both say: “if you couldn’t finish in time, your reward doesn’t count” — which prevents the policy from learning to always generate up to the maximum length.
Flag: --reward.stop_properly_penalty_coef.
4.4.3 GSPO — sequence-level importance sampling
GSPO replaces the per-token PPO ratio with a per-sequence ratio:
\[\rho_{\text{seq}} = \exp\!\left(\frac{1}{|y|} \sum_t (\log \pi_\theta(y_t) - \log \pi_{\text{old}}(y_t))\right)\]The clipped surrogate is then min(ρ_seq · A, clip(ρ_seq) · A) applied uniformly to every token in the sequence — instead of each token having its own ρ_t. From loss.py:
if policy_loss_type == "gspo":
    log_ratio = log_probs - old_log_probs                                    # (B, T)
    ratio = (log_ratio * action_mask).sum(dim=-1) / action_mask.sum(dim=-1)  # (B,)
    ratio = ratio.exp().unsqueeze(-1) * action_mask                          # (B, T), same value repeated
Why: token-level ratios can be noisy because π_old and π_θ may agree on some tokens and disagree on others — averaging across tokens smooths the signal. GSPO also requires sequence-level loss aggregation (not token-level), which is why the code sets self.token_level_loss = False automatically when policy_loss_type == "gspo".
Flag: --actor.policy_loss_type gspo.
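To make the sequence-level clip concrete, here is a minimal PyTorch sketch of the full GSPO surrogate built on the ratio above; the clip range and variable names are illustrative, not OpenRLHF’s exact code, and advantages is assumed to be already broadcast to (B, T):

import torch

def gspo_policy_loss(log_probs, old_log_probs, advantages, action_mask, clip_eps=0.2):
    # One importance ratio per sequence: exp of the length-normalized log-ratio sum.
    log_ratio = (log_probs - old_log_probs) * action_mask
    seq_ratio = (log_ratio.sum(-1) / action_mask.sum(-1).clamp(min=1)).exp()   # (B,)
    ratio = seq_ratio.unsqueeze(-1) * action_mask                              # (B, T), same value repeated
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    per_token = -torch.min(surr1, surr2) * action_mask
    # Sequence-level aggregation: mean over each sequence's action tokens, then over the batch.
    per_seq = per_token.sum(-1) / action_mask.sum(-1).clamp(min=1)
    return per_seq.mean()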
4.4.4 vLLM IS correction: TIS, ICEPOP, seq-mask-TIS
When rollouts come from vLLM but loss is computed on the HF model, the two forward passes aren’t guaranteed bit-identical — numerical drift, different kernel choices, and different KV cache policies can create a sampling distribution mismatch. Fengyao’s off-policy RL writeup formalized this: vLLM is effectively a slightly-different π_rollout from the HF π_old, so training with rollouts drawn from π_rollout is technically off-policy with respect to π_old.
The fix is an extra importance-sampling correction w_t = π_old(y_t) / π_rollout(y_t), applied to the loss:
- TIS (Token-level IS, tis): clamp w_t to [low, high], apply per-token. Gentle.
- ICEPOP (token filter, icepop): set w_t = 0 for tokens outside [low, high] — those tokens contribute nothing. Aggressive.
- seq-mask-TIS (sequence-level filter + token correction, seq-mask-tis): compute the sequence-level geometric mean of w_t; if it’s outside [low, high], zero out the whole sequence; otherwise apply token-level TIS within the kept sequences. A mix.
Flags: --algo.advantage.is_correction_enable, --algo.advantage.is_correction_type {tis,icepop,seq-mask-tis}, --algo.advantage.is_correction_threshold 0.5 5.0.
This kind of correction becomes essential under async training (generation and training run concurrently) and partial rollout (vLLM pause/resume across weight updates), where the off-policy-ness is deliberate, not accidental.
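A minimal sketch of the three variants, assuming per-token logprobs are available from both the rollout engine (vLLM) and the training-side π_old; thresholds and variable names are illustrative, and the returned weights multiply the per-token surrogate loss before aggregation:

import torch

def is_correction(old_log_probs, rollout_log_probs, action_mask,
                  mode="tis", low=0.5, high=5.0):
    # w_t = pi_old(y_t) / pi_rollout(y_t); masked positions end up with log-ratio 0.
    log_w = (old_log_probs - rollout_log_probs) * action_mask
    w = log_w.exp()                                                      # (B, T)
    if mode == "tis":            # clamp per token — gentle
        w = w.clamp(low, high)
    elif mode == "icepop":       # zero out out-of-range tokens — aggressive
        w = torch.where((w >= low) & (w <= high), w, torch.zeros_like(w))
    elif mode == "seq-mask-tis": # sequence-level gate, then token-level TIS inside kept sequences
        seq_w = (log_w.sum(-1) / action_mask.sum(-1).clamp(min=1)).exp() # geometric mean of w_t
        keep = ((seq_w >= low) & (seq_w <= high)).float().unsqueeze(-1)
        w = w.clamp(low, high) * keep
    else:
        raise ValueError(mode)
    return w * action_mask       # only action tokens carry a correction weight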
4.4.5 Dual-clip PPO
Standard PPO clipping bounds the surrogate when advantages are positive, but when advantages are negative the term min(ρ·A, clip(ρ)·A) is unbounded below — a single rollout with a very negative advantage and a very large ratio can produce a huge, destabilizing gradient. Dual-clip PPO adds a second clip on the negative-advantage branch:
\[L_{\text{dual-clip}} = -\max\!\big(\min(\rho A, \text{clip}(\rho) A),\; c \cdot A\big), \quad A < 0,\; c > 1\]From loss.py:
clip1 = torch.min(surr1, surr2) # standard PPO clip
clip2 = torch.max(clip1, self.dual_clip * advantages) # additional lower bound
loss = -torch.where(advantages < 0, clip2, clip1)
4.4.6 DAPO dynamic filtering
DAPO’s other big contribution is filtering out rollout groups with degenerate rewards before the backward pass — if all G rollouts of a prompt got the same reward (all succeeded or all failed), the group baseline equals every reward and the advantages collapse to zero. Including these groups in the backward pass just adds noise.
OpenRLHF implements this as score-based filtering:
--algo.dynamic_filtering_enable
--algo.dynamic_filtering_range 0.0 1.0
Only rollout groups whose mean score falls inside [0.0, 1.0] (configurable) are kept. In practice this keeps the “still-learning” prompts and drops the “solved” and “hopeless” ones. Requires --rollout.n_samples_per_prompt > 1 (need a group to filter).
This composes with oversampling: set --rollout.vllm_generate_batch_size > --rollout.batch_size and generate a surplus, then filter down to the desired training batch.
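A minimal sketch of the group-level filter, assuming one scalar score per rollout, n_samples_per_prompt rollouts per prompt, and exclusive range bounds (so all-0 and all-1 groups are dropped); the names are illustrative:

import torch

def dynamic_filter(scores, n_samples_per_prompt, keep_range=(0.0, 1.0)):
    # scores: (B,) flat scores, laid out as B // n groups of n rollouts per prompt.
    n = n_samples_per_prompt
    group_mean = scores.view(-1, n).mean(dim=-1)                        # (num_prompts,)
    keep = (group_mean > keep_range[0]) & (group_mean < keep_range[1])  # drop saturated groups
    return keep.repeat_interleave(n)                                    # per-rollout boolean mask

# usage: mask = dynamic_filter(scores, n_samples_per_prompt=8); train_batch = rollouts[mask]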
4.4.7 No-std-norm and reference-free training
Two smaller knobs:
- --algo.advantage.no_std_norm disables the global whitening’s division by std, keeping only mean subtraction. Useful when reward magnitudes are already well-calibrated and you don’t want the extra rescaling.
- --algo.kl.init_coef 0 trains without a reference model at all — no KL penalty, no π_ref in memory. Matches DeepSeek-R1-Zero’s setup. Saves substantial VRAM. The policy relies entirely on the reward and the PPO-clip trust region for stability.
4.5 The Key Takeaways
If you only remember a few things about OpenRLHF:
- Everything critic-free is the same algorithm. PPO-clip + KL-in-reward + multi-epoch, differing only in how the per-group reward is shaped before the cumulative sum. Dr. GRPO and REINFORCE++-baseline are one line apart.
- RLOO is unbiased because of leave-one-out. (ΣR − R_i)/(G−1) doesn’t include the sample’s own reward in its own baseline — that’s what makes the estimator unbiased, and it’s an older statistics trick (jackknife) applied to RL.
- Dr. GRPO removes GRPO’s /std to fix difficulty bias. When a prompt’s rollouts all succeed or all fail, GRPO’s advantages blow up; Dr. GRPO (“GRPO Done Right”) just subtracts the mean. In OpenRLHF’s code: one elif branch.
- Multi-turn is about token-level trajectories, not text. The reset/step API produces a single token sequence with marked action_ranges; the trained loss is applied only over actions, and the final reward is spread across all action tokens. Rewards are code, not a dataset column.
- Length control is a reward-shaping problem, not a model-config problem. DAPO’s overlong penalty + ProRL’s stop-properly penalty are both applied to the scalar reward before advantage computation — no changes to the loss math, just to what R means.
- Off-policy RL is unavoidable at scale. Async training, partial rollout, and vLLM/HF drift all make the sampling distribution slightly off from π_old. TIS/ICEPOP/seq-mask-TIS are three flavors of IS correction that keep the training valid despite this.
Part 5: Reward Modeling — Scale, Overconfidence, Calibration
The reward model (RM) is the unsung weak link of RLHF: every other piece can be perfect, but if the RM is mis-scaled or overconfident, the policy will happily exploit it. This section is about what makes the RM hard — the scale-invariance pathology of Bradley-Terry training, why output ranges drift unpredictably, and the calibration tricks that have emerged to keep BT-trained RMs honest.
5.1 The Bradley-Terry Setup, Recap
Architecture (Part 4.4): a transformer backbone with a single Linear(H, 1) scoring head, trained on preference pairs $(x, y_c, y_r)$ via the contrastive log-loss
\[\mathcal{L}_{\text{BT}} = -\log \sigma\big(r_\theta(x, y_c) - r_\theta(x, y_r)\big).\]
The loss depends only on the difference $r_c - r_r$. Adding a constant to every reward is invisible to the loss — the absolute scale and offset are mathematically arbitrary.
This shift-invariance is the root cause of every other issue in this section.
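A two-line check of that invariance, as a minimal PyTorch sketch (using −log σ(x) = softplus(−x)):

import torch
import torch.nn.functional as F

r_c, r_r = torch.tensor(2.3), torch.tensor(1.1)
bt_loss = lambda a, b: F.softplus(-(a - b))   # -log sigmoid(r_c - r_r)

print(bt_loss(r_c, r_r), bt_loss(r_c + 50.0, r_r + 50.0))  # identical: shifting every reward changes nothing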
5.2 The Output-Range Problem
Because the scale is unconstrained, RM outputs drift wherever optimization takes them. In practice this matters because PPO’s policy objective is
\[\mathbb{E}[r_\theta(x,y)] - \beta \cdot \text{KL}(\pi \,\|\, \pi_{\text{ref}}),\]with a fixed coefficient $\beta$. If r_θ ends up centered near 0 with stdev 1, β = 0.01 is reasonable. If r_θ ends up centered near +50 with stdev 20, the same β is effectively zero — the KL constraint vanishes and the policy reward-hacks. If β is then bumped up to compensate, it over-suppresses the policy.
Three layers of practical defenses:
- Mean-centering during RM training. Add an auxiliary loss penalizing $\mathbb{E}_x[r_\theta(x, y)]$ over a reference dataset, so the model is pushed toward a near-zero mean by construction. Doesn’t bound the variance, but anchors the offset.
- Post-training whitening. Before passing rewards to PPO, normalize over a reference batch (often the SFT distribution): r ← (r − μ) / σ. This is what OpenRLHF’s --reward.normalize_enable does — it tracks running mean/std as register_buffers and applies them at inference. Result: rewards typically fall in [-2, +2].
- Reward clipping. Hard [-c, +c] cap to bound the worst-case gradient magnitude. OpenRLHF’s --reward.clip_range flag.
The pattern is: leave the RM unbounded; stabilize downstream. A bounded-output head (e.g. tanh(score) · scale) is uncommon — practitioners prefer post-hoc normalization because it preserves the freedom of the BT loss to find whatever scale converges fastest.
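A minimal sketch of that downstream stabilization: running-statistics whitening followed by a hard clip. The class and buffer names are illustrative, not OpenRLHF’s exact implementation:

import torch
import torch.nn as nn

class RewardNormalizer(nn.Module):
    # Whiten raw RM scores with running mean/std, then clip. Illustrative sketch.
    def __init__(self, clip=5.0, momentum=0.01):
        super().__init__()
        self.clip, self.momentum = clip, momentum
        self.register_buffer("mean", torch.zeros(()))
        self.register_buffer("std", torch.ones(()))

    @torch.no_grad()
    def update(self, rewards):
        self.mean.lerp_(rewards.mean(), self.momentum)
        self.std.lerp_(rewards.std().clamp(min=1e-6), self.momentum)

    def forward(self, rewards):
        whitened = (rewards - self.mean) / self.std     # r ← (r − μ) / σ
        return whitened.clamp(-self.clip, self.clip)    # hard [-c, +c] cap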
5.3 The Overconfidence Problem
The BT loss has a deeper pathology than scale drift: it has no penalty for being too confident. Once the gap $\Delta = r_c - r_r$ is large enough that $\sigma(\Delta) \approx 1$, the gradient gets small but never zero — so optimization keeps pushing $\Delta$ wider. The result is a model that confidently predicts a 99% preference probability for pairs where actual human agreement is closer to 60%.
This is the mechanism behind reward hacking:
- BT-trained RMs become systematically overconfident on out-of-distribution responses.
- During PPO, the policy explores OOD as it improves — verbose answers, repetitive formatting, sycophantic phrasing.
- The RM erroneously assigns these high scores with high confidence.
- The policy chases those scores. Reward goes up; quality goes down.
A more theoretical framing: BT-trained RMs effectively rank by Δ, but they don’t know what Δ means in terms of human preference probability. Calibration — making σ(Δ) actually correspond to the empirical chance of preference — is the missing piece.
5.4 Calibration Techniques
Four techniques worth knowing, roughly in order of how often they’re used in practice:
- Temperature scaling. The simplest post-hoc fix: learn a single scalar T on a held-out preference set such that $\sigma((r_c - r_r) / T)$ matches empirical preference frequencies. During RL, divide rewards by T. Preserves the ranking; softens overconfident extremes. Cheap and surprisingly effective (see the sketch after this list).
- Margin-aware BT loss. Add a per-pair margin term: \[L = -\log \sigma(r_c - r_r - m(y_c, y_r)).\] When fine-grained ratings are available (Likert scales, judge confidence), set m proportional to the rating gap. The model gets a “good enough” stopping point for each pair instead of pushing every gap to infinity.
- Batch-wise Sum-to-Zero Regularization (BSR). Force $\sum_i r_\theta(x_i, y_i) = 0$ within each training batch. Prevents the hidden-state norms of “chosen” and “rejected” pools from drifting apart unboundedly — the mechanism behind dispersion-driven overconfidence. Improves OOD robustness during downstream RLHF.
- Probabilistic Uncertain Reward Models (PURM). Replace the scalar reward with a distribution — the head outputs (μ, σ) per response, and the loss is the BT probability marginalized over Gaussian-distributed rewards. The model learns to widen σ on inconsistent or OOD inputs. Penalizes overconfident predictions on the very inputs where a policy is most likely to reward-hack. Proper uncertainty quantification rather than a post-hoc patch.
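A minimal sketch of the temperature-scaling step, assuming a held-out set of reward gaps Δ = r_c − r_r with binary labels for which response humans actually preferred; this is standard Platt-style calibration, not any particular library’s API:

import torch
import torch.nn.functional as F

def fit_temperature(delta, label, max_iter=200):
    # delta: (N,) held-out reward gaps r_c - r_r; label: (N,) 1.0 if humans preferred "chosen".
    log_t = torch.zeros((), requires_grad=True)        # optimize log T so T stays positive
    opt = torch.optim.LBFGS([log_t], lr=0.1, max_iter=max_iter)

    def closure():
        opt.zero_grad()
        loss = F.binary_cross_entropy_with_logits(delta / log_t.exp(), label)
        loss.backward()
        return loss

    opt.step(closure)
    return log_t.exp().item()

# during RL, divide raw RM rewards by the fitted T before they enter the PPO objective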
5.5 GRPO Mostly Sidesteps This
Worth flagging: GRPO and the RLVR family (verifiable rewards) sidestep most of this. Two reasons:
- Verifiable rewards are calibrated by definition. A binary 0/1 reward from is_correct(answer, ground_truth) doesn’t need calibration — it has no scale to drift, no overconfidence to worry about. This is why RLVR has emerged as such a clean signal for math, code, and other verifier-friendly domains.
- Group-level normalization absorbs scale uncertainty. Even when GRPO uses a neural RM, the per-prompt $(R_i - \bar R) / \sigma_R$ standardization (Block A4 from §4.2) makes the absolute scale of the RM’s output irrelevant within each group. Calibration still matters for cross-prompt comparison, but the worst pathologies are damped.
This is a big part of why DeepSeek-R1’s GRPO + verifier approach took off — it removes the most fragile component of classical RLHF (the BT-trained RM) and replaces it with deterministic ground-truth checks. The cost is task scope (only works where you can write a verifier), but where it works, it works much better.
5.6 The Practical Takeaway
For PPO with a BT-trained RM, the production stack is roughly: train with mean-centering → normalize/whiten outputs → clip rewards → tune β carefully → maybe ensemble multiple RMs. None of these individually solves overconfidence; together they keep it bounded long enough for the policy to learn something useful before reward hacking dominates.
For tasks where verifiable rewards are available, prefer them. A binary correctness check from a unit test is worth ten BT-trained RMs.
The deeper observation: a BT-trained RM is a learned approximation of human preference, with all the calibration and OOD-generalization issues that any learned model has. The RL algorithm treats it as if it were ground-truth signal — and that gap is where reward hacking lives. Every calibration technique above is fundamentally about closing that gap.
Part 6: verl — Production-Scale RL Systems
Where TRL is the readable reference and OpenRLHF is the research toolkit, verl is the production-scale system: ByteDance Seed’s open-source implementation of the HybridFlow paper, the framework that now trains DeepSeek-V3 671B and Qwen3-235B with Megatron + vLLM + Ray. The interesting parts of verl aren’t the algorithms (it implements the same PPO/GRPO/RLOO/DAPO menu) — they’re the systems infrastructure that makes those algorithms actually run on hundreds of GPUs without melting. This section is about that infrastructure.
The story has six pieces, each solving a concrete problem at scale:
- HybridFlow controller — separating control flow (RL algorithm) from computation flow (model engines), so you can swap FSDP↔Megatron and vLLM↔SGLang without rewriting the trainer.
- Hybrid Engine + weight resharding — moving weights between training layout (FSDP / Megatron’s TP-PP-EP) and inference layout (vLLM’s TP) every step, fast.
- Checkpoint Engine — the unified weight-sync abstraction (NCCL / HCCL / NIXL / Mooncake / Kimi) that powers cross-node weight transfer.
- Async training — parallelizing rollout and training to recover the 70%+ of wall-clock time spent on long-tail samples.
- AgentLoop — multi-turn / tool-use RL with token-in-token-out trajectories.
- Transfer Queue — distributed data pool for streaming rollout → train pipelines.
The Q3 2025 roadmap (#2388) is organized around exactly these pillars: composable model engines, modular rollout workers, async/disaggregated architecture, multi-turn/data infra. Reading verl’s code is reading that roadmap as it gets executed.
6.1 The HybridFlow Controller — Separating Control from Computation
The motivation (from the HybridFlow paper and docs/hybrid_flow.rst): an RL training job has two levels of dataflow:
- Control flow — high-level operators (rollout → reward → advantage → train), one operator per step. This is the RL algorithm.
- Computation flow — model forward/backward/optimizer, parallelized across many GPUs. This is the model engine.
In classical RL the control flow is small (scalar arithmetic) so you embed it inside computation. In LLM RL the computation itself is multi-process (FSDP, Megatron, vLLM), so two design choices appear:
- Unified multi-controller (everything multi-process). Best raw performance, but tight coupling — the PPO loop is bound to one specific computation backend.
- Hybrid controller — single-process control, multi-process computation. Looser coupling, slightly more communication overhead. This is what verl picks.
The single-process controller (the RayPPOTrainer “driver”) writes the algorithm as plain Python. The computation lives in WorkerGroups scheduled by Ray, each pinned to a resource pool of GPUs. The driver calls worker methods over RPC; each call dispatches data across DP ranks, runs the work, and collects results back.
The syntactic sugar is the @register decorator, whose dispatch_mode comes from verl.single_controller.base.decorator.Dispatch:
class ActorRolloutRefWorker(Worker):
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, batch: DataProto) -> DataProto:
        # runs on each DP rank; framework auto-splits/auto-collects
        ...
The dispatch_mode (defined at verl/single_controller/base/decorator.py) tells the framework how to slice the input across DP ranks and how to collect outputs — DP_COMPUTE_PROTO is “split along the batch dim, run, gather”. The driver-side call
output: DataProto = actor_rollout_ref_wg.generate_sequences(batch)
is one line, but the framework does the split-dispatch-collect under the hood. You can rewrite the algorithm without touching the engine, or swap the engine without touching the algorithm. This is what makes verl a “modular foundational library” (Q3 roadmap’s first bullet).
Three worker groups are typical for PPO: ActorRolloutRef (training engine + inference engine + frozen ref), Critic, Reward. Each can be placed on its own GPU set, or colocated for fast NCCL weight transfer (more on that next).
6.2 Hybrid Engine — Per-Tensor Weight Resharding
The most-cited verl optimization is weight resharding between training and inference layouts. The problem:
- Training uses FSDP2 or Megatron — each parameter is sharded along data-parallel and possibly tensor-parallel / pipeline-parallel / expert-parallel axes.
- vLLM (or SGLang) uses its own tensor-parallel layout, designed for inference throughput.
Every PPO step, you finish a training update, then need to push the new weights into the rollout engine before generation. Naive copy-via-CPU moves the full model through host memory and would dominate step time on a 671B model.
verl’s solution is per-tensor IPC + bucketed transfer + streaming weight conversion. The relevant code lives in verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py. The BucketedWeightSender packs small training-layout tensors into a fixed-size CUDA IPC buffer, sends a metadata record over ZMQ, then the receiver (vLLM worker) pulls them out into its own layout. Two send buffers ping-pong while one is being filled and the other transferred — overlapping computation and communication.
For sharded layouts, the conversion happens per tensor on the fly via Megatron-bridge weight converters (verl/models/mcore/weight_converter.py — referenced as “per tensor weight resharding” in the DeepSeek 671B doc). At 671B parameters this is what makes the difference between “viable at 96 H20 GPUs” and “OOM”.
6.3 Checkpoint Engine — A Unified Weight Sync Abstraction
When training and rollout live on separate GPU sets (the disaggregated mode used by async training), Hybrid Engine’s IPC trick doesn’t apply — you need network communication. verl’s verl/checkpoint_engine/ provides a unified API:
class CheckpointEngine(ABC):
    def send_weights(self, weights: Generator[(name, tensor)]): ...
    def receive_weights(self) -> Generator[(name, tensor)]: ...
    def get_weights(self) -> Generator[(name, tensor)]: ...
The same three methods are implemented by six backends, each making a different trade-off (from verl/checkpoint_engine/README.md):
| Backend | Comm primitive | Topology | Hardware | Elastic | Use case |
|---|---|---|---|---|---|
| naive | torch.distributed | all_gather | NV/AMD/Ascend | — | colocated on-policy |
| nccl | NCCL | all_gather + broadcast | NV GPU | low | disaggregated, fixed cluster |
| hccl | HCCL | all_gather + broadcast | Ascend NPU | low | same as NCCL but Ascend |
| nixl | NIXL | all_gather + ring p2p | UCX/UCCL/Mooncake | high | dynamic/elastic rollout |
| kimi_ckpt_engine | Mooncake + NCCL/HCCL | p2p + broadcast | NV/Ascend | low | save-checkpoint each sync |
| mooncake | Mooncake transfer engine | all_gather + ring p2p | NV/Ascend | high | fixed cluster, MoE |
The NCCL backend is the workhorse. Reported single-step sync times from the fully-async-policy README:
| Model | Trainer ranks | Rollout ranks | Without ckpt-engine | With ckpt-engine |
|---|---|---|---|---|
| Qwen2.5-Math-7B | 4 | 4 | 0.12s | 0.02s |
| Qwen3-30B-A3B | 16 | 16 | 15.76s | 4.38s |
| Qwen3-235B-A22B | 64 | 64 | 58.57s | 23.70s |
A 235B sync drops from a minute to 24 seconds. At sync frequency of every few mini-steps, that’s the difference between async training being a win and a wash.
6.3.1 NCCL Backend — Detailed Walkthrough
A natural mental model is: “trainer-0 gathers sharded weights, then there’s one CUDA process group between trainer-0 and rollout, the bucket goes out, and each rollout rank picks the slice it needs”. That’s almost right — with two important corrections after reading verl/checkpoint_engine/nccl_checkpoint_engine.py:
Correction 1 — all trainer ranks call send_weights, but only rank 0 talks on NCCL. The build_topology assigns trainer rank 0 to NCCL rank 0, and all other trainer ranks to sentinel -1:
# NCCLCheckpointEngine.build_topology
trainer_kwargs = {
"rank": [0] + [-1] * (trainer_world_size - 1), # only trainer-0 in the NCCL group
"world_size": [rollout_world_size + 1] * trainer_world_size,
}
rollout_kwargs = {
"rank": list(range(1, rollout_world_size + 1)), # rollout ranks 1..N
"world_size": [rollout_world_size + 1] * rollout_world_size,
}
So the NCCL group has size N + 1: trainer-0 plus all rollout ranks. The other trainer ranks are not in this group at all. But they still call send_weights — because the weight generator they iterate (engine.get_per_tensor_param()) does the gather as a side effect of yielding each tensor, and gather is a collective across all trainer ranks. Iterating the generator on rank -1 keeps those gather collectives in sync; the rank -1 branch then drops the data on the floor:
async def send_weights(self, weights):
    assert self.rank <= 0, "Trainer workers other than rank 0 should not send weights."
    if self.rank < 0:
        for name, weight in weights:  # drain — keeps the model engine's gathers in lockstep
            pass
        return
    # ... rank-0-only code below: bucketize and broadcast ...
Correction 2 — rollout ranks do not pick a shard. The NCCL primitive used is collective.broadcast(bucket, src_rank=0), not scatter. Every rollout rank receives the identical full-tensor bucket. The actual TP slicing happens one layer downstream, inside vLLM, after the checkpoint engine has handed it the full tensor over CUDA IPC.
The architecture diagram from checkpoint_engine/base.py makes this concrete:
trainer (model engine + ckpt engine in same process) rollout (separate ckpt-engine process per GPU)
┌────────┬────────┬─────┬────────┐ ┌───────────────────┬───────────────────┐
│ ┌────┐ │ ┌────┐ │ │ ┌────┐ │ │ Replica 0 │ Replica 1 │
│ │ ME0│ │ │ ME1│ │ │ │ MEn│ │ ├────┬────┬────┬────┼────┬────┬────┬────┤
│ └──┬─┘ │ └────┘ │ ... │ └────┘ │ │ 0 │ 1 │ 2 │ 3 │ 0 │ 1 │ 2 │ 3 │ ← vLLM WorkerProcs
│ ▼ │ │ │ │ └──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┘
│ ┌──┴─┐ │ ┌────┐ │ │ ┌────┐ │ ▲ ▲ ▲ cuda ipc ▲ ▲ ▲
│ │ CE │ │ │ CE │ │ │ │ CE │ │ ┌──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┐
│ └──┬─┘ │ └────┘ │ │ └────┘ │ │ CE │ CE │ CE │ CE │ CE │ CE │ CE │ CE │ ← CheckpointEngineWorkers
└────┼───┴────────┴─────┴────────┘ └──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┘
▼ │ │ │ │ │ │ │ │
└────────────── NCCL broadcast ──────────┴────┴────┴────┴────┴────┴────┴────┘
rank 0 ranks 1..N
So there are really three transports, each at a different level:
- Model-engine-internal gather across trainer ranks — turns sharded weights into full tensors. Happens inside engine.get_per_tensor_param().
- NCCL broadcast from trainer-0 to every rollout rank — ships full tensors as raw bytes. Same content reaches every rank.
- CUDA IPC from each rollout rank’s CheckpointEngineWorker to its colocated vLLM WorkerProc on the same GPU — vLLM then re-shards the full tensor along its own TP layout via update_weights.
6.3.2 The NCCL Pipeline — Pseudocode
Three pieces: setup, send-loop, receive-loop. Names match the actual code (NCCLCheckpointEngine, BroadcastOperation).
Setup — runs once per training run, or per-sync if rebuild_group=True (elastic mode):
# CheckpointEngineManager.build_process_group()
# 1. Each worker allocates two CUDA buffers (send_buf, recv_buf), each `bucket_size` bytes.
# Trainer-0 also starts a ZMQ PUB server and returns its (ip, port).
metadata = ray.get(
trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size) +
rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size)
)
# 2. Build communication topology:
# NCCL group = { trainer-0 (rank 0) } ∪ { rollout ranks 1..N }
trainer_kwargs, rollout_kwargs = NCCLCheckpointEngine.build_topology(
trainer_world_size, rollout_world_size, metadata
)
# trainer_kwargs["rank"] = [0, -1, -1, ..., -1] ← only trainer-0 in the NCCL group
# rollout_kwargs["rank"] = [1, 2, 3, ..., N]
# both have world_size = N+1, master_metadata = trainer-0's (zmq_ip, zmq_port)
# 3. Each worker calls init_process_group(rank, world_size, master_metadata):
# - rank == 0: bind ZMQ PUB on (ip, port); init_collective_group(N+1, 0, "nccl")
# - rank > 0: connect ZMQ SUB to trainer-0; init_collective_group(N+1, rank, "nccl")
# - rank < 0: skip — sentinel for non-leader trainer workers
collective.barrier(group_name)
Trainer-0 sending (send_weights):
async def send_weights(self, weights): # weights = engine.get_per_tensor_param()
if self.rank < 0: # non-leader trainer ranks
for _ in weights: pass # drain — keep gather collectives alive
return
send_buf, recv_buf = self.send_buf, self.recv_buf # two ping-pong buffers
broadcast_op = None
bucket_meta = {}
offset = 0
for name, weight in weights: # full tensors, gathered by ME on the fly
if offset + weight.nbytes > bucket_size:
# Bucket is full → kick off broadcast of current bucket.
torch.cuda.synchronize()
if broadcast_op is not None:
await broadcast_op.wait_for_complete() # wait for previous broadcast to land
broadcast_op = BroadcastOperation(
rank=0,
bucket=send_buf, # broadcast THIS buffer
metadata={"bucket_meta": bucket_meta, "is_last": False},
socket=self.socket, # ZMQ PUB
topic="bucket_metadata",
)
# Inside BroadcastOperation._run() (runs in a separate thread):
# 1. socket.send_string(topic); socket.send_pyobj(metadata)
# 2. collective.broadcast(send_buf, src_rank=0, group_name)
# Swap buffers — start filling the OTHER one while NCCL is broadcasting.
send_buf, recv_buf = recv_buf, send_buf
bucket_meta, offset = {}, 0
# Pack this weight into send_buf at `offset`, record its (offset, shape, dtype).
bucket_meta[name] = TensorMeta(name=name, shape=weight.shape,
dtype=weight.dtype, offset=offset)
send_buf[offset : offset + weight.nbytes] = cp.asarray(weight.view(-1).view(uint8))
offset += weight.nbytes
# Final bucket — set is_last=True so receivers know to stop after.
torch.cuda.synchronize()
if broadcast_op is not None:
await broadcast_op.wait_for_complete()
broadcast_op = BroadcastOperation(
rank=0, bucket=send_buf,
metadata={"bucket_meta": bucket_meta, "is_last": True},
socket=self.socket, topic="bucket_metadata",
)
await broadcast_op.wait_for_complete()
Rollout rank receiving (receive_weights — async generator):
async def receive_weights(self):
send_buf, recv_buf = self.send_buf, self.recv_buf
# First bucket — receive into recv_buf and BLOCK on it (no previous to overlap with).
op = BroadcastOperation(
rank=self.rank, bucket=recv_buf, metadata=None,
socket=self.socket, topic="bucket_metadata",
)
# Inside _run():
# 1. socket.recv_string() → socket.recv_pyobj() → metadata
# 2. collective.broadcast(recv_buf, src_rank=0, group_name) ← writes into recv_buf
metadata = await op.wait_for_complete()
# Pipelined loop: yield from filled buffer while next broadcast lands in the other.
send_buf, recv_buf = recv_buf, send_buf # filled buffer is now `send_buf`
while not metadata["is_last"]:
# Kick off receiving the NEXT bucket into the (currently empty) recv_buf.
op = BroadcastOperation(rank=self.rank, bucket=recv_buf, metadata=None, ...)
# Meanwhile, parse out tensors from the bucket we already have.
for name, meta in metadata["bucket_meta"].items():
size = meta["dtype"].itemsize * meta["shape"].numel()
tensor = send_buf[meta["offset"] : meta["offset"] + size] \
.view(dtype=meta["dtype"]).view(meta["shape"])
yield name, tensor # ← consumer (vLLM via CUDA IPC) pulls this
# Wait for the next bucket to land, then swap.
metadata = await op.wait_for_complete()
torch.cuda.synchronize()
send_buf, recv_buf = recv_buf, send_buf
# Last bucket — drain it.
for name, meta in metadata["bucket_meta"].items():
...
yield name, tensor
The upstream call sites that wire everything together:
# verl/workers/engine_workers.py — trainer-side weight push
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
async def update_weights(self, global_steps=None):
if self.config.rollout.checkpoint_engine.backend != "naive":
per_tensor_param, _ = self.actor.engine.get_per_tensor_param() # ← ME gathers shards
await self.checkpoint_engine.send_weights(per_tensor_param) # ← CE bucketed broadcast
return
# ... naive (colocated) path: just IPC to local vLLM ...
# verl/checkpoint_engine/base.py — rollout-side weight pull
class CheckpointEngineWorker(Worker):
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
async def update_weights(self, global_steps=None):
weights = self.checkpoint_engine.receive_weights() # ← async generator
await self.server_adapter.update_weights(weights, ...) # ← cuda ipc to vLLM
6.3.3 Four Design Points to Remember
- NCCL group has size rollout_world_size + 1 — trainer-0 plus all rollout ranks. Other trainer ranks aren’t in this group; they’re at sentinel rank -1 and only iterate the weight generator to drive the model-engine’s gather collectives.
- Two-buffer ping-pong for pipelined broadcast. While NCCL broadcasts bucket K (in a thread, via loop.run_in_executor), the producer fills bucket K+1. wait_for_complete() is the join point.
- Out-of-band metadata over ZMQ PUB/SUB carries the (name, shape, dtype, offset) records. NCCL only ships the raw bucket bytes; the metadata packets are small enough to go over ZMQ. Decoupling metadata from data means rollout ranks don’t need to know the layout in advance, and receivers can unpack the bucket structure dynamically.
- No sharding at the checkpoint-engine layer. Each rollout rank receives the same full tensor; vLLM does its own TP slicing inside update_weights. This is what keeps the checkpoint engine generic across vLLM/SGLang/etc — engine-specific layout knowledge stays in the engine.
The rebuild_group flag lets you destroy and recreate the NCCL communicator each sync — needed for elastic mode where the rollout worker set changes between syncs.
6.4 Async Training — Recovering the Long Tail
In synchronous colocated PPO, rollout dominates wall-clock time. The fully_async_policy README quotes ~70% of total time in DAPO 32B going to rollout, and the one-step-off doc shows the same pattern. The reason is the long tail: for a batch of prompts, total rollout time is bounded below by the slowest generation in the batch. A few prompts that hit max_new_tokens block everything else. Adding more rollout GPUs doesn’t help — the slowest sample still has to finish.
verl’s async stack tackles this through three increasingly aggressive modes, all in verl/experimental/:
6.4.1 One-Step-Off Async (one_step_off_policy/)
Simplest version: while the trainer updates with batch N, the rollout generates batch N+1. The trainer then trains on N+1 once it arrives.
# from one_step_off doc
batch_data_future = self._async_gen_next_batch(continuous_iterator) # launch first
while batch_data_future is not None:
batch = batch_data_future.get() # wait for previous
batch_data_future = self._async_gen_next_batch(...) # launch next async
# ...train on batch...
The policy is always exactly 1 step off (samples were generated by the previous policy). Reported gains: +23–40% throughput on DAPO 32B (FSDP2 / Megatron).
6.4.2 Fully-Async Policy (fully_async_policy/)
Generalizes to N steps off and streaming rollout. Four components (from fully_async_policy/README.md):
┌────────────────────────┐
│ Rollouter │ generates samples one at a time
│ (FullyAsyncRollouter) │ controlled by staleness
└──────────┬─────────────┘
│ put_sample()
▼
┌────────────────────────┐
│ MessageQueue │ ray actor, async queue
└──────────┬─────────────┘
│ get_sample()
▼
┌────────────────────────┐
│ Trainer │ consumes samples;
│ (FullyAsyncTrainer) │ every trigger_parameter_sync_step
└──────────┬─────────────┘ triggers a weight sync
│
▼
┌────────────────────────┐
│ ParameterSynchronizer │ NCCL via checkpoint_engine
└────────────────────────┘
Three knobs control how aggressive the async-ness is:
- trigger_parameter_sync_step — how many local trainer updates before pushing weights to rollout. Larger = more off-policy, less sync overhead.
- staleness_threshold — max ratio of “stale” (old-policy) samples allowed in training. 0 = synchronous; >0 = async.
- partial_rollout — when triggering a weight sync, interrupt in-flight rollouts (vLLM sleep/resume), sync, then resume. Avoids waiting for the longest sample.
These produce four operating modes:
| Mode | trigger_sync | staleness | partial | Behavior |
|---|---|---|---|---|
| a on-policy | 1 | 0 | — | strictly on-policy, has gen idle bubble |
| b stream off-policy | >1 | 0 | — | streaming, but waits for last batch |
| c async + stale | ≥1 | >0 | False | overlap, but waits for in-flight at sync |
| d async + partial | ≥1 | >0 | True | interrupt + resume on sync — most aggressive |
Reported result from the README: 2.35–2.67× speedup on Qwen2.5-Math-7B / 128 GPUs, with comparable accuracy. The biggest gain comes from mode (d) with partial_rollout=True and staleness_threshold=0.5 — the long tail is eliminated because the rollout never blocks the next sync.
The trainer’s FullyAsyncTrainer.fit_step looks recognizably like PPO but with async ergonomics:
async def fit_step(self):
    batch = await self._fit_generate(None)       # _get_samples_from_queue()
    batch = self._fit_compute_reward(batch)
    batch = self._fit_compute_log_prob(batch)    # uses rollout_log_probs by default
    batch = self._fit_compute_ref_log_prob(batch)
    batch = self._fit_compute_advantage(batch)
    batch = self._fit_update_critic(batch)
    batch = self._fit_update_actor(batch)
    self._fit_update_local_step()
    await self._fit_update_weights()             # only every trigger_parameter_sync_step
Note use_rollout_log_probs=True by default — the importance-ratio denominator is the rollout-time logprob, not a separately-recomputed π_old from the training engine. This matters for off-policy correctness (Part 4.4.4): when samples come from an old policy version, you need to know exactly which logprob they were sampled with, and the rollout reports that directly.
The algorithm.rollout_correction.bypass_mode=False flag (sketch from _compute_old_log_prob) lets you re-derive old_log_prob from a saved CPU snapshot of the version-1 weights — restoring the trainer to those weights, computing logprobs, then restoring back. This is what AReaL calls “Decoupled PPO”. Expensive but exact.
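A minimal sketch of what “use the rollout-time logprobs as the denominator” means for the importance ratio; variable names are illustrative:

import torch

def ppo_ratio(new_log_probs, rollout_log_probs, old_log_probs=None,
              use_rollout_log_probs=True):
    # With use_rollout_log_probs=True the denominator is the logprob recorded by the rollout
    # engine at sample time (the distribution the tokens were actually drawn from);
    # otherwise it is a freshly recomputed pi_old from the training engine.
    denom = rollout_log_probs if use_rollout_log_probs else old_log_probs
    return (new_log_probs - denom).exp()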
6.5 AgentLoop — Multi-Turn and SWE-Style Tool RL
For tool-using agents (web search, code execution, sandbox runs — the SWE-bench style of RL), verl’s abstraction is the AgentLoop. The architecture lives in verl/experimental/agent_loop/:
PPOTrainer.generate_sequences()
│
▼
AgentLoopManager.generate_sequences()
│ wake_up servers, sync weights
│ split batch into chunks
▼
AgentLoopWorker (Ray actor) ← one per chunk; runs many agent loops concurrently
│ for each prompt:
│ spawn AgentLoopBase.run() ← user code (e.g. ToolAgentLoop)
▼
AsyncLLMServerManager ← request-level load balancer + sticky sessions
│ route to least-loaded server
▼
AsyncLLMServer (vLLM/SGLang) ← per-DP-group server, generate(prompt_ids) → response_ids
The user-defined entry point is AgentLoopBase.run:
class AgentLoopBase(ABC):
    @abstractmethod
    async def run(self, sampling_params, **kwargs) -> AgentLoopOutput:
        ...

class AgentLoopOutput(BaseModel):
    prompt_ids: list[int]
    response_ids: list[int]   # LLM-generated + tool response token ids
    response_mask: list[int]  # 1 = LLM-generated, 0 = tool response ← loss mask
The output is what feeds into PPO: response_ids is the entire post-prompt token sequence, response_mask marks which slices were the model’s actions vs. environment feedback. Same token-in-token-out invariant as OpenRLHF (Part 4.3) — and for the same reason: re-tokenizing chat-formatted text round-trips badly enough that PPO won’t converge.
For tool-using agents, the ready-made implementation is ToolAgentLoop, organized as a state machine:
class AgentState(Enum):
    PENDING           # apply chat template, prepare prompt
    GENERATING        # call AsyncLLMServerManager.generate
    PROCESSING_TOOLS  # parse tool calls, execute in parallel, append responses
    TERMINATED        # max turns / max length / no more tool calls

while state != AgentState.TERMINATED:
    if state == AgentState.PENDING: state = await self._handle_pending_state(...)
    elif state == AgentState.GENERATING: state = await self._handle_generating_state(...)
    elif state == AgentState.PROCESSING_TOOLS: state = await self._handle_processing_tools_state(...)
Termination conditions: response length ≥ response_length, assistant turns ≥ max_assistant_turns, user turns ≥ max_user_turns, or no tool calls in last response. Tools are dispatched in parallel via asyncio.gather, with max_parallel_calls capping concurrency per turn.
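A minimal sketch of that parallel dispatch pattern, an asyncio.gather over a capped slice of tool calls; call_tool here is an illustrative stand-in for verl’s _call_tool, not its exact signature:

import asyncio

async def dispatch_tools(tool_calls, call_tool, max_parallel_calls=4):
    # Run up to max_parallel_calls tool calls concurrently; responses come back in order.
    batch = tool_calls[:max_parallel_calls]
    return await asyncio.gather(*(call_tool(tc) for tc in batch))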
Why a separate AsyncLLMServer?
The agent loop needs the LLM as a service: many concurrent multi-turn conversations, each making generate calls at different rates. A single colocated vLLM engine would serialize them. So verl exposes vLLM/SGLang as a server (one per DP group of the inference engine), with an AsyncLLMServerManager in front doing:
- Least-in-flight load balancing — first-turn requests go to the least-busy server.
- Sticky sessions — multi-turn conversations route to the same server via request_id, so vLLM’s prefix cache hits on the second turn.
- Token-in-token-out API — generate(prompt_ids: list[int], sampling_params) returns response_ids: list[int], never round-tripping through text.
The GlobalRequestLoadBalancer Ray actor sees all in-flight counts globally. Multi-turn calls from the same conversation ID always hit the same server (because the first call cached the prefix there), and new conversations go to whichever server has fewest in-flight.
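A minimal sketch of the routing policy (sticky sessions keyed by request_id, least-in-flight otherwise); this illustrates the idea, not GlobalRequestLoadBalancer’s actual code:

class StickyLeastLoadedRouter:
    def __init__(self, servers):
        self.servers = list(servers)                 # opaque, hashable server handles
        self.in_flight = {s: 0 for s in self.servers}
        self.sessions = {}                           # request_id → server

    def acquire(self, request_id):
        server = self.sessions.get(request_id)
        if server is None:                           # first turn: pick the least-busy server
            server = min(self.servers, key=lambda s: self.in_flight[s])
            self.sessions[request_id] = server       # later turns stick to it (prefix-cache hit)
        self.in_flight[server] += 1
        return server

    def release(self, server):
        self.in_flight[server] -= 1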
This is what makes “multi-turn tool-calling RL with vLLM” actually work at scale. The recipe in recipe/retool/ (now in the verl-recipe submodule) builds on ToolAgentLoop with a code sandbox, and the fully_async_policy recipe extends it with AsyncPartialToolAgentLoop that supports partial_rollout interruption mid-conversation.
6.6 The Full Picture: Fully-Async Multi-Turn Tool-Use Architecture
So far we’ve covered the pieces in isolation: HybridFlow controller (6.1), weight resharding (6.2), checkpoint engine (6.3), async training (6.4), and AgentLoop (6.5). Putting them together gives the production architecture for industrial-grade agentic RL — fully-async + multi-turn + tool-use, the configuration that trains things like SWE agents, coding assistants with code sandboxes, and reasoning agents with retrieval. This is what recipe/fully_async_policy + recipe/retool + ToolAgentLoop looks like, end to end.
Below is the topology diagram. Every named box is an actual class or Ray actor in the verl source.
┌──────────────────────────────────────────────────────────────┐
│ driver (head node, no GPU) │
│ fully_async_main.py → FullyAsyncTaskRunner │
└──────────┬─────────────────────────────────────┬─────────────┘
│ .remote() │ .remote()
▼ ▼
┌──────────────────────────────────────────┐ ┌──────────────────────────────────────────┐
│ FullyAsyncRollouter (Ray actor, 0 GPU)│ │ FullyAsyncTrainer (Ray actor, 0 GPU) │
│ inherits SeparateRayPPOTrainer │ │ inherits SeparateRayPPOTrainer │
│ ┌──────────────────────────────────────┐ │ │ ┌──────────────────────────────────────┐ │
│ │ AgentLoopManager │ │ │ │ CheckpointEngineManager │ │
│ │ ├ rollout_replicas: [RolloutReplica │ │ │ │ ├ trainer wg │ │
│ │ │ × K] │ │ │ │ └ replicas: rollout_replicas │ │
│ │ ├ load_balancer: │ │ │ │ fit_step() → PPO update │ │
│ │ │ GlobalRequestLoadBalancer │ │ │ │ MetricsAggregator │ │
│ │ │ (Ray actor, sticky+LB) │ │ │ └─────────────┬────────────────────────┘ │
│ │ └ agent_loop_workers: │ │ │ │ Dispatch.DP_COMPUTE_PROTO│
│ │ AgentLoopWorker × W │ │ │ ▼ │
│ │ (Ray actors, 0 GPU, │ │ └──────────────────────────────────────────┘
│ │ hold ToolAgentLoop coroutines) │ │
│ └──────────────────────────────────────┘ │
└──────────────────────────────────────────┘
─── rollout GPU pool (rollout.nnodes × n_gpus) ─── ─── trainer GPU pool (trainer.nnodes × n_gpus) ───
┌──────────────────────────────────────────────┐ ┌──────────────────────────────────────────────┐
│ RolloutReplica r (mode = STANDALONE) │ │ RayWorkerGroup of DetachActorWorker │
│ ├ name: rollout_pool_r │ │ (extends ActorRolloutRefWorker) │
│ ├ AsyncLLMServer × (DP groups) │ │ rank 0..T-1 │
│ │ (Ray actor; vLLM AsyncLLMEngine │ │ ├ FSDP2 / Megatron model engine │
│ │ in same proc; ZMQ to ModelRunner) │ │ ├ engine.get_per_tensor_param() │
│ ├ ModelRunner × tp_size │ │ ├ compute_log_prob, update_actor │
│ │ (vLLM WorkerProc per GPU) │ │ └ checkpoint_engine = │
│ └ CheckpointEngineWorker × per GPU │ ◄══════════ NCCL bucketed broadcast ══════════════│
│ (separate process, colocated; │ │ NCCLCheckpointEngine │
│ CUDA IPC ↔ vLLM WorkerProc) │ │ (rank 0 sends; others rank -1) │
└──────────────┬───────────────────────────────┘ ├──────────────────────────────────────────────┤
▲ │ TrainingWorker (Critic, optional) │
│ └──────────────────────────────────────────────┘
│ AsyncLLMServerManager.generate.remote(prompt_ids)
│ (Ray RPC, sticky-routed via load_balancer)
│
┌─────────┴─────────────────────────────────────┐
│ ToolAgentLoop (or AsyncPartialToolAgentLoop)│
│ spawned per prompt as an asyncio coroutine│
│ inside an AgentLoopWorker │
│ │
│ state machine: │
│ PENDING ──► GENERATING ◄────┐ │
│ │ │ │
│ ▼ │ │
│ PROCESSING_TOOLS ──┘ │
│ │ │
│ ▼ │
│ TERMINATED │
│ │
│ tools (multi_turn.tool_config_path): │
│ ├── code sandbox (e.g. SandboxFusion) │
│ ├── search / browser │
│ ├── shell / file edit │
│ └── ... user-defined Tool subclasses │
│ asyncio.gather across max_parallel_calls │
│ │
│ returns AgentLoopOutput { │
│ prompt_ids, │
│ response_ids, │
│ response_mask (1 = LLM, 0 = tool tokens) │
│ } │
└──────────────────┬────────────────────────────┘
│ MessageQueueClient.put_sample(traj)
▼
┌──────────────────────────────────────────────────────────────────────────┐
│ MessageQueue (Ray actor) │
│ put_sample / get_sample, async, max_queue_size │
│ staleness control: rollout pauses when ahead by │
│ (1 + staleness_threshold) · trigger_parameter_sync_step │
│ · require_batches · ppo_mini_batch_size samples │
└────────────────────────────┬─────────────────────────────────────────────┘
│ get_sample()
└──────► FullyAsyncTrainer.fit_step()
The ══════════ NCCL bucketed broadcast ══════════ arrow between the two GPU pools is the weight-sync channel: a single NCCL process group of size rollout_world_size + 1 whose rank 0 is trainer-0 (the only rank in RayWorkerGroup<DetachActorWorker> that participates), and ranks 1..N are the rollout-side CheckpointEngineWorkers. Bucket payload travels via NCCL broadcast; bucket metadata ({name: (offset, shape, dtype)}) travels out-of-band over ZMQ PUB/SUB; the sync fires every trigger_parameter_sync_step updates. See §6.3.1 for the full walkthrough.
Three things this diagram makes concrete that prose alone can’t:
- Two GPU pools, never colocated. rollout.nnodes × n_gpus and trainer.nnodes × n_gpus are separate resource pools (hybrid_engine=False is asserted in FullyAsyncRollouter.__init__). Weights are pushed across the cluster via NCCL, not via in-process IPC.
- The driver actors hold the agent infrastructure, but agent loops run on CPU. FullyAsyncRollouter itself sits on no GPU; it owns an AgentLoopManager which spawns W AgentLoopWorker Ray actors (also CPU-only, max_concurrency=100). Each worker runs many ToolAgentLoop coroutines concurrently. The GPUs in the rollout pool are owned by RolloutReplica instances, which expose AsyncLLMServer Ray actors as the only GPU-touching surface.
- AgentLoop reaches vLLM by Ray RPC, not an in-process call. ToolAgentLoop runs in AgentLoopWorker (CPU); when it needs to generate, it calls AsyncLLMServerManager.generate(prompt_ids), which (a) acquires a server actor handle from GlobalRequestLoadBalancer (sticky session for multi-turn), (b) issues server.generate.remote(prompt_ids, ...) over Ray RPC. The AsyncLLMServer actor lives in the rollout GPU pool and forwards the request to the vLLM AsyncLLMEngine running in the same process; vLLM then talks to its ModelRunner workers over ZMQ. Token-in-token-out is preserved across all three hops — no text round-trip.
6.6.1 Lifecycle of one rollout trajectory
Following one prompt from dataset entry to gradient update, with the actual class/method names:
1. Driver creates actors. FullyAsyncTaskRunner (in fully_async_main.py) spawns FullyAsyncRollouter.remote(...), FullyAsyncTrainer.remote(...), and MessageQueue.remote(...). The trainer’s init_workers() spins up the trainer RayWorkerGroup of DetachActorWorkers. The rollouter’s _init_async_rollout_manager() constructs AgentLoopManager, which calls _initialize_llm_servers() to spin up K RolloutReplicas in RolloutMode.STANDALONE (each with its own resource pool) and _init_global_load_balancer() for GlobalRequestLoadBalancer. Then _init_agent_loop_workers() creates W Ray-remote AgentLoopWorker actors.
2. Rollouter pulls a prompt. FullyAsyncRollouter._streaming_generation_main() loops; _feed_samples() reads one row at a time from the dataset; _processor_worker() schedules its processing.
3. AgentLoopWorker spawns a coroutine. Per prompt, the worker calls agent_loop = ToolAgentLoop(server_manager=AsyncLLMServerManager(...)) and awaits agent_loop.run(sampling_params, **kwargs). The state machine starts at AgentState.PENDING.
4. PENDING → GENERATING. _handle_pending_state applies the chat template (with tool_schemas) → prompt_ids. State advances.
5. GENERATING. _handle_generating_state calls await self.server_manager.generate(prompt_ids, sampling_params). Inside AsyncLLMServerManager.generate:
   - _acquire_server(request_id) — Ray RPC to GlobalRequestLoadBalancer, returns the least-loaded server’s actor handle (or the cached one if multi-turn).
   - await server.generate.remote(prompt_ids, ...) — Ray RPC to an AsyncLLMServer actor in the rollout GPU pool.
   - The AsyncLLMServer calls vLLM’s AsyncLLMEngine.generate(prompt_ids) in-process; vLLM dispatches to its ModelRunners over ZMQ; tokens stream back as a TokenOutput (token_ids, text, logprobs).
   - _release_server decrements the in-flight counter.
   - Append output.token_ids to agent_data.response_ids, and [1] * len(output.token_ids) to agent_data.response_mask.
6. Tool calls. ToolParser.extract_tool_calls(response_ids, tools) parses the response. If non-empty, transition to PROCESSING_TOOLS.
7. PROCESSING_TOOLS. _handle_processing_tools_state builds an asyncio.gather over [self._call_tool(tc, ...) for tc in tool_calls[:max_parallel_calls]]. Tools come from multi_turn.tool_config_path — for SWE work, that’s typically a code sandbox plus bash/file_edit tools. Tool responses get tokenized and appended; mask is [0] * len(tool_tokens) (env feedback, not LLM-generated). Transition back to GENERATING.
8. TERMINATED. Once response_length >= max_response_length or assistant_turns >= max_assistant_turns or no more tool calls, return AgentLoopOutput(prompt_ids, response_ids, response_mask, ...).
9. Sample enters the queue. The rollouter packs the trajectory into a RolloutSample and calls message_queue_client.put_sample(sample). If the queue is full (because rollout has produced more than (1 + staleness_threshold) · trigger_parameter_sync_step · require_batches · ppo_mini_batch_size samples), the rollouter blocks until the trainer drains it.
10. Trainer consumes. FullyAsyncTrainer.fit_step() calls _get_samples_from_queue(), accumulating until it has require_batches × ppo_mini_batch_size samples, then runs the standard PPO pipeline: _fit_compute_reward → _fit_compute_log_prob → _fit_compute_advantage → _fit_update_actor. Importance ratios use the rollout-time logprobs by default (use_rollout_log_probs=True), since those samples may have been generated by an older policy version.
11. Periodic weight sync. Every trigger_parameter_sync_step updates, _fit_update_weights() calls checkpoint_manager.update_weights(global_steps). With partial_rollout=True, the rollout side abort_all_requests first (vLLM sleep / save in-flight state), the CheckpointEngineManager.build_process_group rebuilds the NCCL group, the trainer’s rank-0 DetachActorWorker streams weights through engine.get_per_tensor_param() into bucketed NCCL broadcasts, the rollout-side CheckpointEngineWorkers receive and CUDA-IPC them to their colocated vLLM WorkerProcs, and finally resume_generation brings the in-flight conversations back online. Trainer goes back to step 10; rollouter resumes producing.
6.6.2 What runs where
For mental quick-reference:
| Component | Type | GPU? | What it does |
|---|---|---|---|
| FullyAsyncTaskRunner | driver | no | spawns the two actors + the queue |
| FullyAsyncRollouter | Ray actor | no | owns AgentLoopManager, drives streaming generation |
| FullyAsyncTrainer | Ray actor | no | runs PPO fit_step loop, owns CheckpointEngineManager |
| MessageQueue | Ray actor | no | async queue for streaming samples |
| AgentLoopManager | Python obj | no | manages replicas + workers + load balancer |
| AgentLoopWorker | Ray actor | no | runs many ToolAgentLoop coroutines concurrently |
| ToolAgentLoop | coroutine | no | per-prompt state machine with tool calls |
| AsyncLLMServerManager | Python obj | no | client-side LLM gateway with sticky sessions |
| GlobalRequestLoadBalancer | Ray actor | no | global least-loaded routing |
| RolloutReplica | wrapper | (owns GPUs) | mode=STANDALONE; one resource pool per replica |
| AsyncLLMServer | Ray actor | yes | hosts vLLM AsyncLLMEngine, exposes generate.remote() |
| ModelRunner (vLLM) | process | yes | actual GPU forward/decode worker, one per TP rank |
| CheckpointEngineWorker | Ray actor | yes (rollout side) | colocated with vLLM, NCCL-receive then CUDA-IPC weights |
| RayWorkerGroup<DetachActorWorker> | worker group | yes (trainer side) | FSDP2/Megatron training engine, gathers per-tensor params |
| TrainingWorker (Critic) | worker | yes (trainer side) | optional, only for PPO with explicit critic |
| NCCLCheckpointEngine | Python obj | n/a | bucketed NCCL broadcast, ZMQ metadata |
Two key resource isolations:
- CPU-only Ray actors carry coordination logic. Anything that does scheduling, parsing, queueing, load balancing, or running coroutines is on the CPU. This keeps GPUs saturated with actual model work.
- GPU-resident actors are kept tiny. AsyncLLMServer is a thin Ray wrapper around vLLM’s engine; DetachActorWorker is a thin Ray wrapper around an FSDP/Megatron training engine; CheckpointEngineWorker is a thin Ray wrapper around an NCCL bucket forwarder. The “hot” GPU code paths (vLLM ModelRunner, training engine kernels, NCCL collectives) all run inside these wrappers, not as separate Ray hops.
This separation is what makes 671B-scale agentic RL viable: the orchestration layer is async-Python over Ray RPC and has no problem scaling to thousands of concurrent in-flight conversations, while the GPU layer only sees the operations it’s optimized for (forward/backward, generate, broadcast).
6.7 Transfer Queue — Distributed Data Pool
The Q3 roadmap’s “P1: distributed data pool” item (citing AsyncFlow) is now implemented as a separate transfer_queue package, integrated via verl/utils/transferqueue_utils.py. The motivation: in a fully-async pipeline with separate Rollouter and Trainer, the obvious architecture (Rollouter → MessageQueue → Trainer) puts every sample through a single Ray actor, which becomes a bottleneck at scale.
Transfer Queue replaces the central queue with a distributed pool. Workers publish BatchMeta (or KVBatchMeta) records — small metadata packets describing which fields are stored where — via tq.get_client().async_put(data, metadata=meta). Other workers fetch by metadata: tq.get_client().async_get_data(meta). The actual tensors travel point-to-point between the producing and consuming workers, not through the controller.
For verl, the integration is via the Dispatch system: a worker method decorated with the right dispatch mode auto-publishes its outputs into Transfer Queue and auto-fetches its inputs from there. The driver process never sees the data — it only passes around lightweight BatchMeta references. This is the same idea as Ray’s object store, but specialized for tensor batches with field-level granularity.
The Q3 roadmap’s “use tensordict and nested-tensor to remove padding and replace DataProto” is the companion change: instead of fixed-shape padded tensors flowing through DataProto, you get nested TensorDicts with variable-length sequences. Combined with Transfer Queue, the data plane finally matches the irregular nature of LLM rollouts.
6.8 Q3 Roadmap as a Map
Reading issue #2388 with this section in hand, the system-side priorities map cleanly:
- “Composable model engines” — finish the FSDP/Megatron decoupling started in PR #1977 so any engine can plug into any worker role. The HybridFlow paper’s promise, fully realized.
- “Modular rollout workers” — VllmRolloutWorker and SGLangRolloutWorker exposing the same API. Weight resharding explicitly called out: “optimize tp x dp dispatch, and support receiving weight from separate resource groups”.
- “Multi-turn, data, config infra” — better message infra for multi-turn (dense reward), tensordict + nested-tensor to drop padding, distributed data pool (TransferQueue from AsyncFlow).
- “Streamline new model workflow” — abstraction for multi-modal models (currently rope handling, freeze/unfreeze, IO are inconsistent across VLMs).
- “High quality recipes” — ReTool ✓, SOTA VLM RL, DAPO at larger model sizes.
The throughline is “foundational library, not destination”: every roadmap item is about removing coupling so the community can swap pieces in and out. The HybridFlow controller (§6.1) is the why; the modular rollout workers, checkpoint engine, async pipelines, and Transfer Queue are the how.
6.9 The Practical Picture
If you put the whole stack together, a verl async PPO training step looks like:
- Rollouter spins up AsyncLLMServers (vLLM in server mode), AgentLoopManager schedules ToolAgentLoop instances, each running a state machine (PENDING → GENERATING → PROCESSING_TOOLS → TERMINATED), dispatching generate(prompt_ids) calls through the load balancer to whichever server is least busy.
- Completed trajectories put_sample into the MessageQueue (or the tq distributed pool).
- Trainer pulls samples one by one, accumulating until it has require_batches × ppo_mini_batch_size worth of data, then runs an inner update (PPO-clip on the rollout-time logprobs, no separate old-logprob recompute when bypass_mode=True).
- Every trigger_parameter_sync_step updates, the trainer hands its weights to the Checkpoint Engine (NCCL-bucketed broadcast over a pre-built process group), which streams them into the rollout workers.
- With partial_rollout=True, the rollout was paused (vLLM sleep) right before the sync — the in-flight conversations were saved with their state and resumed after the new weights are loaded. No long-tail wait.
- Lather, rinse, repeat.
Compared to TRL or OpenRLHF colocated PPO, verl trades implementation simplicity for resource flexibility: you can put 64 GPUs on the trainer, 64 on the rollout, and a different parallelism strategy on each side; you can swap FSDP for Megatron just by changing config; you can train DeepSeek-V3 671B because per-tensor resharding makes the Megatron↔vLLM weight transfer fit in 96 H20 GPUs of memory.
The price of this flexibility is configuration complexity (many YAML knobs) and a learning curve (HybridFlow controller + Ray + Megatron + vLLM is four mental models stacked). But for production-scale training — anything past 30B parameters or beyond a single-node setup — verl’s system design is what makes the math from Parts 1–5 actually run.
6.10 Pseudocode
Three pseudocode blocks for the core flows. Names match the actual verl code so you can grep them in the repo.
6.10.1 HybridFlow controller — the dispatch decorator
# verl/single_controller/base/decorator.py
class Dispatch(DynamicEnum):
DP_COMPUTE_PROTO = ... # split-along-batch, gather
# in a Worker
class ActorRolloutRefWorker(Worker):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, batch: DataProto) -> DataProto:
# runs on ONE DP rank with this rank's slice of the batch
return self.rollout.generate(batch)
# driver-side, in RayPPOTrainer.fit():
batch: DataProto = sample_from_dataset(...) # full batch, single-process
out: DataProto = actor_rollout_ref_wg.generate_sequences(batch)
# framework auto-splits along batch.shape[0], dispatches to each worker,
# collects + concatenates outputs. The driver writes ONE line.
6.10.2 Fully-async streaming RL loop
# verl/experimental/fully_async_policy/
# Actor 1: Rollouter (FullyAsyncRollouter)
async def _streaming_generation_main(self):
while running:
# produce samples one at a time, controlled by staleness budget
if too_far_ahead_of_trainer():
await self._resume_event.wait() # pause
sample = await llm_engine.generate(...) # one trajectory
await self.message_queue_client.put_sample(sample)
# Actor 2: Trainer (FullyAsyncTrainer)
async def fit(self):
while True:
await self.fit_step() # consumes K samples, updates once
async def fit_step(self):
batch = await self._get_samples_from_queue() # pull require_batches mini-batches
batch = self._fit_compute_reward(batch)
batch = self._fit_compute_log_prob(batch) # uses rollout_log_probs by default
batch = self._fit_compute_advantage(batch) # PPO/GRPO/RLOO/DrGRPO
batch = self._fit_update_actor(batch) # PPO-clip
self._fit_update_local_step() # local_trigger_step += 1
if self.local_trigger_step == 1: # rolled over to new param version
await self._fit_update_weights() # checkpoint_engine NCCL broadcast
# Actor 3: ParameterSynchronizer (CheckpointEngine subclass)
async def update_weights(self, global_steps):
if partial_rollout:
await rollouter.pause() # vLLM sleep, save in-flight state
weights = self.training_engine.iter_named_parameters()
await self.checkpoint_engine.send_weights(weights) # NCCL bucketed broadcast
if partial_rollout:
await rollouter.resume() # vLLM resume, continue interrupted
6.10.3 ToolAgentLoop — multi-turn tool calling
# verl/experimental/agent_loop/tool_agent_loop.py
@register("tool_agent")
class ToolAgentLoop(AgentLoopBase):
async def run(self, sampling_params, **kwargs) -> AgentLoopOutput:
agent_data = AgentData(messages=kwargs["raw_prompt"], ...)
state = AgentState.PENDING
while state != AgentState.TERMINATED:
if state == AgentState.PENDING:
# apply chat template → prompt_ids
state = await self._handle_pending_state(agent_data, sampling_params)
elif state == AgentState.GENERATING:
# call AsyncLLMServerManager.generate(prompt_ids) → response_ids
output = await self.server_manager.generate(
request_id=agent_data.request_id,
prompt_ids=agent_data.prompt_ids,
sampling_params=sampling_params,
)
agent_data.response_ids = output.token_ids
agent_data.prompt_ids += output.token_ids
agent_data.response_mask += [1] * len(output.token_ids) # 1 = LLM-generated
_, agent_data.tool_calls = await self.tool_parser.extract_tool_calls(
agent_data.response_ids, tools
)
if agent_data.tool_calls and not at_max_turns():
state = AgentState.PROCESSING_TOOLS
else:
state = AgentState.TERMINATED
elif state == AgentState.PROCESSING_TOOLS:
# fan out tool calls in parallel, append responses
tasks = [self._call_tool(tc, agent_data.tools_kwargs, agent_data)
for tc in agent_data.tool_calls[: self.max_parallel_calls]]
responses = await asyncio.gather(*tasks)
for tool_response, _, _ in responses:
tool_tokens = self.tokenizer.encode(tool_response.text)
agent_data.prompt_ids += tool_tokens
agent_data.response_mask += [0] * len(tool_tokens) # 0 = env feedback
state = AgentState.GENERATING # back to LLM
# Final output: token sequence + loss mask. PPO trains only on mask=1 positions.
return AgentLoopOutput(
prompt_ids=agent_data.prompt_ids[: -len(agent_data.response_mask)],
response_ids=agent_data.prompt_ids[-len(agent_data.response_mask):],
response_mask=agent_data.response_mask,
...
)
The mask-based loss application is the same pattern as OpenRLHF’s action_ranges (Part 4.3): environment-feedback tokens are in the sequence (the LLM saw them as context) but not in the policy gradient.
6.11 Key Takeaways
If you only remember a few things about verl:
- HybridFlow’s central trick: separate single-process control flow (the algorithm) from multi-process computation flow (the engines). The `@register(dispatch_mode=...)` decorator hides split-dispatch-collect behind one driver-side method call. This is what lets verl swap FSDP↔Megatron and vLLM↔SGLang without rewriting the trainer.
- Per-tensor weight resharding is what makes 671B RL viable. Training and inference layouts disagree; resharding fast (bucketed CUDA IPC + ZMQ metadata + Megatron weight converters) keeps the per-step transfer time bounded, which is the difference between “fits on 96 H20” and “OOM”.
- Checkpoint Engine is a unified weight-sync abstraction with six backends (NCCL/HCCL/NIXL/Mooncake/Kimi/naive). The same `send_weights`/`receive_weights` API papers over wildly different transports — colocated all-gather, cross-node NCCL broadcast, RDMA, MoE-aware p2p. A 235B sync drops from 58s to 24s with the right backend.
- Async training has four flavors (on-policy, stream off-policy, async with staleness, async + partial rollout) controlled by three knobs (`trigger_parameter_sync_step`, `staleness_threshold`, `partial_rollout`). Mode (d) — partial rollout with vLLM sleep/resume across syncs — fully eliminates the long-tail bubble and yields 2.35–2.67× wall-clock speedup at production scale.
- AgentLoop is verl’s multi-turn / tool-use abstraction. `AgentLoopBase.run` returns `(prompt_ids, response_ids, response_mask)` — token-in-token-out, with the mask determining loss application. `ToolAgentLoop` is a state machine (PENDING → GENERATING → PROCESSING_TOOLS → TERMINATED) with parallel tool execution. `AsyncLLMServerManager` provides load-balanced, sticky-session vLLM access. This is what makes SWE-style RL training work — the same interface across single-turn, multi-turn, and tool-using rollouts.
- The Q3 roadmap’s throughline is decoupling — composable engines, modular rollout workers, distributed data pool, async pipeline, replace DataProto with TensorDict. Every system change in verl is about removing a coupling so a community contributor can swap one component without rewriting the others.
Part 7: Diagnosing RL Jobs — A Practitioner’s Runbook
The math from Parts 1–5 and the systems from Part 6 give you the capability to train an LLM with RL. They don’t tell you whether it’s working. Unlike SFT — where a smoothly decreasing loss curve is a reliable progress signal — RL training is a multi-signal system where everything can look fine on the dashboard while the model silently learns to cheat. This final part is the operating manual: what to log, what the numbers mean, what’s failing when they go wrong, and which lever to pull.
7.1 The Minimum Dashboard
Set up these metrics before the first run. Every framework we covered (TRL, OpenRLHF, verl) emits versions of them; the names below follow TRL/OpenRLHF conventions. Healthy ranges are approximate and depend on model size, vocab, and reward scale — what matters is that each metric stays stable and trends in the right direction, not the exact value.
| Metric | Healthy | Red flag | What it tells you |
|---|---|---|---|
| `reward/mean` | gradual increase, eventually plateaus | flat at zero, OR vertical spike in <20 steps | Primary signal of learning. Flatline = broken reward or impossible task. Vertical spike = reward hacking (§5.3). |
| `reward/std` (per-prompt group) | non-zero on most batches | persistently zero | All G rollouts of every prompt got the same reward. For GRPO (`group_norm`) this kills the gradient via the /std; for group-mean variants (Dr. GRPO, REINFORCE++-baseline, RLOO) R − mean = 0 also kills the gradient. Either way: no learning signal — your data is too easy, too hard, or your reward function is broken. |
| `frac_reward_zero_std` | < 0.5 | → 1.0 | The fraction of prompt groups with zero reward variance, batch over batch. The ideal training curve has this dropping over time as the policy improves and harder prompts start to differentiate. |
| `kl` (against π_ref) | rises slowly to a stable value | < 0.005 (no learning) or > 2 (drift / collapse) | Drift from the SFT reference. KL too small means the policy isn’t moving; KL too large means it’s wandered out of the RM’s training distribution and reward signals stop being meaningful. |
| `entropy` (per-token, nats) | slow gradual decline | sudden plummet, OR pinned near log(vocab size) | “How peaked the next-token distribution is.” Sudden plummet = mode collapse (§7.2.1). Per-token entropy at training start is typically 1–3 nats and drops over RL. Specific values depend on vocab size and tokenizer. |
| `clip_ratio` (PPO/GRPO surrogate) | < 0.1 | > 0.3 sustained | Fraction of tokens hitting the [1−ε, 1+ε] clip bounds. High = the policy is trying to take large steps each update. Drop the LR or `num_ppo_epochs`. |
| `completions/mean_length` | rises gradually, plateaus | monotone rise hitting `max_new_tokens` | Models naturally reason longer over RL. Pegging at the cap on >80% of completions is length hacking (§7.2.2). |
| `completions/clipped_ratio` | < 0.3 | > 0.7 | Fraction of trajectories truncated by `max_new_tokens`. Cause and consequence of length hacking. |
| `grad_norm` (pre-clip) | bounded, stable | repeated spikes ≫ clip threshold | Pre-clip grad norm; with `max_grad_norm=1.0` clipping, post-clip is bounded, but pre-clip tells you whether the clipping is actively saving you on most steps. Persistent spikes mean something upstream (advantage scale, reward outliers) is broken. |
| `value_loss` (PPO only) | trending down, low scale | oscillating wildly, persistently large | Critic accuracy. Bad critic → bad GAE → bad policy gradient. The critic is PPO’s most fragile component (§7.2.3). |
| `IS_ratio/min`, `IS_ratio/max` (async + vLLM) | clustered around 1 | → 0 or → ∞ | Importance-sampling correction (TIS / ICEPOP, §4.4.4). Far from 1 means the rollout sampler and the learner have drifted apart; gradient estimates are increasingly biased. Tighten the weight-sync interval. |
| held-out validation accuracy | tracks training reward | training reward up, val flat or down | The only metric that isn’t gameable. Run a small eval every 50–100 steps. If training reward decouples from val accuracy, you’re reward-hacking. |
A good run shows: reward and val accuracy rising together, KL drifting up to a steady plateau, entropy declining slowly, clip ratio under 10%, value loss decaying. Any of these going off the rails is the signal — Part 7.2 walks through what each failure looks like.
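To make the two group-variance rows concrete, here is a minimal sketch of how `reward/std` (per group) and `frac_reward_zero_std` could be computed from a batch of rewards. It assumes rewards arrive as a `[num_prompts, G]` array with G rollouts per prompt; it is illustrative, not any framework's actual logging code.
```python
import numpy as np

def group_reward_stats(rewards: np.ndarray) -> dict:
    """rewards: [num_prompts, G] array, G rollouts per prompt (assumed layout)."""
    group_std = rewards.std(axis=1)                    # per-prompt reward spread
    return {
        "reward/mean": float(rewards.mean()),
        "reward/std": float(group_std.mean()),         # average per-group std
        "frac_reward_zero_std": float((group_std < 1e-8).mean()),
    }

# Example: 3 prompts x 4 rollouts with a binary verifiable reward.
print(group_reward_stats(np.array([
    [1, 1, 1, 1],   # trivial prompt: no gradient signal
    [0, 0, 0, 0],   # hopeless prompt: no gradient signal
    [1, 0, 1, 0],   # informative prompt
])))                # frac_reward_zero_std = 2/3
```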
7.2 The Five Canonical Failure Modes
RL training fails in a small number of recognizable ways. Pattern recognition saves days.
7.2.1 Entropy collapse (mode collapse)
Symptom. entropy plummets in the first hundred steps. The model starts emitting near-identical responses for every prompt — often a high-reward template (“The answer is X” or “Let me think step by step…”). reward/mean may climb briefly before stalling.
Mechanism. PPO clipping is asymmetric in token-probability space. The upper bound 1 + ε lets a token at probability 0.9 grow to ~0.99, but only lets a token at probability 0.01 grow to ~0.012 — a tiny absolute gain. So already-likely tokens accumulate probability while exploration tokens stay flat. The distribution sharpens until exploration is dead.
Levers.
- Clip-higher (DAPO). Decouple the bounds: `epsilon_low=0.2, epsilon_high=0.28`. Asymmetrically widening the upper clip gives exploration tokens room to grow when they have positive advantage (see the sketch after this list).
- Strengthen the KL anchor. Increase β (or use the adaptive controller, §7.3). The reference model has full entropy; pulling toward it preserves exploration.
- Token-level loss aggregation (Dr. GRPO, §4.2). Equal-weighting tokens across the batch instead of equal-weighting responses prevents the within-response averaging from amplifying short, high-confidence patterns.
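To ground the clip-higher lever, here is a minimal token-level sketch of the decoupled-clip surrogate. The `epsilon_low`/`epsilon_high` defaults mirror the numbers in the bullet above; the function itself is an illustration, not verl's or OpenRLHF's actual loss code.
```python
import torch

def clip_higher_surrogate(logp_new, logp_old, advantages,
                          epsilon_low=0.2, epsilon_high=0.28):
    """PPO-clip surrogate with decoupled bounds (DAPO-style clip-higher).

    Widening only the upper bound gives low-probability tokens with positive
    advantage more room to grow per step, which counteracts entropy collapse.
    """
    ratio = torch.exp(logp_new - logp_old)
    clipped = torch.clamp(ratio, 1.0 - epsilon_low, 1.0 + epsilon_high)
    return -torch.minimum(ratio * advantages, clipped * advantages).mean()
```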
7.2.2 Reward hacking
The most insidious failure because the dashboard says success. Reward goes up, training is stable, val accuracy stagnates or drops.
Length hacking. Model discovers that longer responses score slightly higher (more keywords, more guesses, RM length bias). completions/mean_length climbs monotonically until most are truncated at max_new_tokens.
- Lever: DAPO overlong penalty (§4.4.1) — a soft linear ramp on responses past `max_new_tokens − overlong_buffer_len` (sketched below). Or the ProRL stop-properly penalty (§4.4.2) for truncated samples. Or, simpler, length-normalize the reward.
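A sketch of the soft ramp, assuming the penalty is added to the scalar task reward. The argument names echo the knobs above; the penalty scale (`max_penalty`) is a free choice, not a prescribed value.
```python
def overlong_penalty(response_len: int,
                     max_new_tokens: int = 4096,
                     overlong_buffer_len: int = 512,
                     max_penalty: float = 1.0) -> float:
    """Soft linear penalty on responses entering the overlong buffer.

    Zero below (max_new_tokens - overlong_buffer_len), then ramps linearly
    down to -max_penalty at the hard cap. Add the result to the task reward.
    """
    start = max_new_tokens - overlong_buffer_len
    if response_len <= start:
        return 0.0
    overflow = min(response_len - start, overlong_buffer_len)
    return -max_penalty * overflow / overlong_buffer_len
```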
Format / gibberish hacking. With rule-based rewards (RLVR), the model produces gibberish in the reasoning trace but emits a correctly-formatted answer tag (<answer>42</answer>). Reward is high; the trace is unreadable.
- Detection: training reward is high but per-token entropy spikes and KL grows fast.
- Lever: add a secondary reward signal that requires a well-formed reasoning trace (e.g., presence of `<think>...</think>` with non-trivial content), and filter out questions where random guessing has a high success rate (e.g., 4-way multiple choice). A toy composite reward follows.
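One toy way to wire in that secondary signal; the tag name, minimum trace length, and weights are illustrative assumptions, not a standard recipe.
```python
import re

def composite_reward(completion: str, answer_correct: bool,
                     min_think_chars: int = 200,
                     format_weight: float = 0.2) -> float:
    """Correctness reward plus a small bonus/penalty for a real reasoning trace."""
    m = re.search(r"<think>(.*?)</think>", completion, flags=re.DOTALL)
    has_trace = m is not None and len(m.group(1).strip()) >= min_think_chars
    task = 1.0 if answer_correct else 0.0
    return task + (format_weight if has_trace else -format_weight)
```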
Pattern collapse / RM exploitation. Model finds a stylistic pattern (verbose, sycophantic, bullet points) that the RM consistently scores high. The RM has spurious correlations from its training data, and the policy rides them out of distribution.
- Detection: the reward distribution skews heavily right-tailed; OOD probes show the policy generating things the RM has no business scoring confidently. This is the calibration problem from §5.
- Lever: stronger KL constraint, RM ensembling with disagreement penalty, periodic RM refresh on new policy-generated data.
7.2.3 Value function divergence (PPO only)
Symptom. value_loss oscillates or stays high. Shortly after, policy performance degrades.
Mechanism. The critic is solving a hard regression problem: predict the scalar return from a partial token sequence whose target distribution shifts every update. If it falls behind, GAE feeds the policy noise (see the Part 2 Q&A on the Critic’s loss). Bad advantage signs → wrong-direction updates.
Levers.
- Initialize from the RM, not from scratch. The RM has the same shape as a critic head (a single-output linear layer on top of a transformer) and is a useful starting point.
- Critic warm-up. Freeze the actor for the first 100–500 steps and train only the critic until the value loss stabilizes (a sketch follows this list).
- Higher LR for the critic than for the actor. Common config: critic LR ≈ 9e-6, actor LR ≈ 5e-7. The critic is doing regression and needs to keep up; the actor is doing policy gradient and needs to be conservative.
- Switch to a critic-free algorithm. GRPO / Dr. GRPO / REINFORCE++ all sidestep this entirely. If the critic is your bottleneck, the cheapest fix is often to remove it.
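A minimal sketch of the warm-up pattern from the second bullet. The loss callables and optimizer handles are placeholders for whatever your trainer exposes; real frameworks usually surface this as a single warm-up-steps config knob.
```python
def ppo_step_with_critic_warmup(batch, policy_loss_fn, value_loss_fn,
                                actor_opt, critic_opt,
                                step: int, critic_warmup_steps: int = 200):
    """One training step where the actor stays frozen until the critic settles."""
    v_loss = value_loss_fn(batch)             # critic regression toward returns
    critic_opt.zero_grad(); v_loss.backward(); critic_opt.step()

    if step >= critic_warmup_steps:           # actor only updates after warm-up
        p_loss = policy_loss_fn(batch)        # PPO-clip surrogate
        actor_opt.zero_grad(); p_loss.backward(); actor_opt.step()
```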
7.2.4 Sampler-learner gap (async / vLLM mismatch)
Symptom. IS_ratio/min drops near zero or IS_ratio/max blows up. sampling_logp_difference/mean climbs above ~1. Training looks stable but performance flatlines.
Mechanism. The rollouts came from π_rollout (vLLM, possibly several updates ago); the policy gradient is being computed at π_θ (HF training engine, current). When they drift far apart, the importance-sampled gradient estimate is no longer a low-variance approximation of the true gradient — it’s biased and noisy. This is the off-policy problem we discussed in §4.4.4.
Levers.
- Reduce `staleness_threshold` (or `trigger_parameter_sync_step`) to sync weights more often.
- Enable IS correction. TIS (token-clamp), ICEPOP (token-filter), or seq-mask-TIS — pick by how aggressive you want to be (§4.4.4); a minimal token-clamp variant is sketched after this list. For agentic multi-turn tool use, sequence-level filtering is usually safest because the trajectory is already long.
- Enable partial rollout (`partial_rollout=True`, §6.4). Interrupting and resuming in-flight conversations across syncs caps how stale any single sample can be.
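For orientation, a minimal token-level TIS-style correction: per-token importance weights between learner and rollout logprobs, clamped from above so a few drifted tokens can't dominate. The clamp value is an assumption, and this is not the exact implementation in any of the frameworks discussed.
```python
import torch

def tis_weighted_loss(per_token_loss, learner_logp, rollout_logp, clamp_max=2.0):
    """Truncated importance sampling across the sampler-learner gap.

    ratio = pi_learner / pi_rollout per token; detached so it only reweights
    the existing loss rather than adding a new gradient path.
    """
    ratio = torch.exp(learner_logp - rollout_logp).detach()
    weight = torch.clamp(ratio, max=clamp_max)
    return (weight * per_token_loss).mean()
```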
7.2.5 Empty-gradient batches (data issues)
Symptom. frac_reward_zero_std near 1.0 batch after batch. reward/std per-group is zero. Loss looks suspiciously stable; nothing actually improves.
Mechanism. All G rollouts of every prompt produced the same reward. For group-mean estimators (RLOO, Dr. GRPO, REINFORCE++-baseline) the centered reward R − mean = 0. For GRPO (group_norm) the /std makes it worse — division by ~0 either zeros out the advantage or blows it up. Either way, no useful gradient.
This usually means the prompts in this batch are too easy or too hard for the current policy. With binary verifiable rewards (RLVR) the failure mode is bimodal: a “trivial” prompt where all G samples succeed, or a “hopeless” prompt where all G samples fail.
Levers.
- Difficulty filtering. Pre-pass the prompt set with the current policy: estimate Pass@1, drop prompts with `Pass@1 > 0.9` (too easy) and `Pass@1 < 0.05` (too hard), and keep the middle (a sketch follows this list).
- DAPO dynamic sampling (§4.4.6, the `--algo.dynamic_filtering_enable` flag in OpenRLHF). Oversample at generation, then drop saturated groups before the backward pass. Composes with `--rollout.vllm_generate_batch_size > --rollout.batch_size`.
- Increase G (group size). With G=4, the chance of all-same-reward on a hard prompt is high; with G=16 it's much lower. Doubling group size is often cheaper than retuning anything else.
- Diversify data. Mix domains (math, code, instruction, chat). A monoculture batch is more likely to all-succeed or all-fail together.
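A sketch of the difficulty pre-pass from the first bullet. `sample` and `verify` are hypothetical stand-ins for your rollout helper and 0/1 verifier; the thresholds follow the numbers above.
```python
def filter_by_difficulty(prompts, sample, verify, n=8,
                         min_pass=0.05, max_pass=0.9):
    """Keep prompts whose estimated Pass@1 falls in the informative middle band."""
    kept = []
    for prompt in prompts:
        completions = sample(prompt, n)      # n rollouts with the current policy
        pass_rate = sum(verify(prompt, c) for c in completions) / n
        if min_pass <= pass_rate <= max_pass:
            kept.append(prompt)
    return kept
```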
7.3 Prevention: Standard Setup
Most failed runs were preventable in setup. The following should be defaults, not optional tweaks.
Reward normalization is not optional. The raw output of a BT-trained RM is on an arbitrary scale (§5.2). Z-score whiten rewards across the batch (--reward.normalize_enable in OpenRLHF; advantage whitening in TRL/PPO; Block D in §4.2). This equalizes per-prompt contribution to the policy gradient.
Reward clipping for the long tail. Even after whitening, the occasional outlier reward causes a gradient spike. --reward.clip_range [-10, 10] or similar is a cheap insurance policy.
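The two defaults in one place, as a sketch rather than any framework's option parsing; the clip range follows the example flag above.
```python
import numpy as np

def normalize_and_clip_rewards(rewards, clip_range=(-10.0, 10.0), eps=1e-6):
    """Z-score whiten rewards across the batch, then clip the remaining tail."""
    r = np.asarray(rewards, dtype=np.float64)
    r = (r - r.mean()) / (r.std() + eps)
    return np.clip(r, *clip_range)
```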
Adaptive KL controller. Hard-coding β doesn’t scale across model sizes. Set a target KL window (e.g., [0.05, 0.2]); when KL goes above the target, multiply β by some factor; when below, divide. This is what TRL’s AdaptiveKLController and OpenRLHF’s args.algo.kl.target give you.
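A minimal window-based controller in the spirit described above. The multiply/divide-by-a-factor rule is the simple variant from the text, not a copy of TRL's AdaptiveKLController.
```python
class WindowKLController:
    """Adjust beta so that the measured KL stays inside [low, high] (illustrative)."""

    def __init__(self, beta=0.05, low=0.05, high=0.2, factor=1.5):
        self.beta, self.low, self.high, self.factor = beta, low, high, factor

    def update(self, observed_kl: float) -> float:
        if observed_kl > self.high:
            self.beta *= self.factor      # policy drifting too fast: pull harder
        elif observed_kl < self.low:
            self.beta /= self.factor      # policy barely moving: relax the anchor
        return self.beta
```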
Grad clipping at 1.0. Standard. Pre-clip norms of 5–10 are common during early training; clipping at 1.0 keeps you safe without cutting away too much signal.
Group size ≥ 8 for group-relative methods. With 4 rollouts per prompt, the group mean is a noisy baseline; advantages are high variance. 8 is the practical floor; 16 is a comfortable default.
Token-level loss aggregation. As discussed (§4.2), Dr. GRPO and DAPO normalize per-token across the batch instead of averaging per-sequence first. This eliminates the response-length bias and tends to make entropy and length curves more stable.
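The difference between the two aggregations, sketched over a ragged batch with a 0/1 response mask (same mask convention as the AgentLoop pseudocode in §6.10.3). Neither function is a specific library's exact loss.
```python
import torch

def seq_mean_loss(token_loss, mask):
    """Per-sequence mean first, then batch mean: short responses weigh more per token."""
    per_seq = (token_loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return per_seq.mean()

def token_mean_loss(token_loss, mask):
    """One global mean over all valid tokens (the Dr. GRPO / DAPO-style aggregation)."""
    return (token_loss * mask).sum() / mask.sum().clamp(min=1)
```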
Held-out eval every 50–100 steps. A small (50–200 prompt) clean validation set is the only ground-truth signal you have. Cheap, fast, and the first place a reward-hacking failure shows up.
7.4 Pre-Flight Checklist
Before launching a run:
- Reward function smoke test. Pipe 100 sample completions through the reward function manually. Do high scores correspond to genuinely good outputs? Does an empty string get rewarded? Does a single-token answer score the same as a thoughtful response? Thirty minutes here saves days of debugging (a throwaway harness is sketched after this checklist).
- Effective batch size. For GRPO/RLOO/Dr. GRPO/REINFORCE++-baseline: at least 8 completions per prompt; at least 256 prompts per gradient step. For PPO: enough rollouts that `value_loss` will have a stable target.
- Normalization is enabled. Reward whitening, advantage whitening (PPO/GAE/REINFORCE++/REINFORCE++-baseline), or the `/std` (GRPO).
- KL controller is configured. Target window set, initial β reasonable for your model scale, KL logged every step.
- Held-out eval set. 50–200 prompts, evaluated every 50–100 training steps, logged alongside training reward.
- Early-stop criteria defined upfront. “If KL > 2 for 20 steps, kill the run.” “If entropy collapses to near zero for 10 steps, kill the run.” RL runs that go off the rails do not self-correct — kill early, free the GPUs, debug.
- For PPO: critic init from RM or SFT model (never random); critic LR > actor LR; consider a critic-only warm-up phase.
- For async / disaggregated: weight sync frequency configured, IS correction (TIS / ICEPOP / seq-mask-TIS) enabled, `partial_rollout=True` if `staleness_threshold > 0`.
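A throwaway harness for the first checklist item. `reward_fn` and the probe strings are placeholders for whatever your setup uses; the point is to eyeball the numbers before spending GPU time.
```python
def smoke_test_reward(reward_fn, sample_completions):
    """Print reward scores for degenerate probes and for the top-scoring samples."""
    probes = {
        "empty string": "",
        "single token": "42",
        "format only": "<think></think><answer>42</answer>",
    }
    for name, text in probes.items():
        print(f"{name:>13}: reward = {reward_fn(text):.3f}")
    # Spot-check real samples: do the highest-scoring ones look genuinely good?
    for text in sorted(sample_completions, key=reward_fn, reverse=True)[:10]:
        print(round(reward_fn(text), 3), text[:120].replace("\n", " "))
```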
7.5 Algorithm-Specific Quick Reference
Tying the failure modes back to the algorithm choice:
- PPO — most resource-intensive (4 models), but has the strongest token-level credit assignment when the critic stays healthy. Single biggest risk: critic divergence. Spend disproportionate effort on critic init, warm-up, and monitoring `value_loss`.
- GRPO (`group_norm`) — simplest critic-free option but suffers from the `/std` difficulty bias: prompts where all G rollouts succeed or all fail produce zero or exploding advantages (§4.2). Pair with the DAPO modifications (clip-higher, dynamic sampling, token-level loss, overlong penalty) and difficulty filtering.
- Dr. GRPO (`dr_grpo`) — drops the `/std` to fix the difficulty bias. Better default than GRPO for RLVR tasks where rewards are binary or coarsely discrete.
- RLOO (`rloo`) — leave-one-out unbiased baseline; particularly valuable when the group size G is small (G=4–8). No global advantage whitening (§4.2).
- REINFORCE++ / REINFORCE++-baseline — global batch whitening rescues critic-free training when the batch is large enough. The OpenRLHF README and ScaleRL/Magistral results report this is more stable than GRPO and faster than PPO at scale, though “stability” here is configuration-dependent — what really matters is that batch whitening absorbs reward-scale variation that group-relative methods can’t.
The shared throughline: the difference between the algorithms is small (a few lines of advantage shaping; §4.2). The difference between running them well and running them badly is big — and lives in the dashboard, the reward engineering, and the data curation. Most “GRPO doesn’t work” reports turn out to be “this batch was too easy” or “the reward function had a length bias” or “the critic never converged” — failures in §7.2 and §7.3, not in the algorithm itself.
If there’s one thing to take from this entire post, it’s that the math you implement is rarely the thing that decides whether a run succeeds. The post-training stack is now mature enough that any of the algorithms in §1, §4, and §6 will work — if you’ve set up the dashboard from §7.1, if you recognize the failure modes in §7.2 early, and if the reward signal isn’t gameable in the first place. The hard part isn’t choosing PPO vs GRPO vs Dr. GRPO. The hard part is everything around the choice.