This guide details the Group Relative Policy Optimization (GRPO) implementation within NeMo RL. We walk through data handling, policy model training, fast generation, and the GRPO loss function.
To get started quickly, use the script examples/run_grpo.py, which demonstrates how to train a model on math problems using GRPO. You can launch this script locally or through Slurm. For detailed instructions on setting up Ray and launching a job with Slurm, refer to the cluster documentation.
We recommend launching the job using uv:
uv run examples/run_grpo.py --config <PATH TO YAML CONFIG> {overrides}
If not specified, the config defaults to examples/configs/grpo_math_1B.yaml.
Reminder: Do not forget to set your HF_HOME, WANDB_API_KEY, and HF_DATASETS_CACHE (if needed). You'll need to do a huggingface-cli login as well for Llama models.
In this guide, we'll walk through how we handle:
- Data
- Model training
- Fast generation
- Overall resource flow
- Loss
We support training with multiple RL "Environments" at the same time.
An Environment is an object that accepts a state/action history and returns an updated state and the rewards for that step. Environments run as Ray remote actors; see MathEnvironment for an example.
To support this, we need to know:
- What environments you have
- Which data should go to which environments
- How to prepare the data from your dataset into a form we can use
GRPO datasets in NeMo RL are encapsulated using classes. Each GRPO data class is expected to have the following attributes:
- dataset: A dictionary containing the formatted datasets. Each example in the dataset must conform to the format described below.
- task_name: A string that uniquely identifies the dataset.
GRPO datasets are expected to follow the HuggingFace chat format. Refer to the chat dataset document for details. If your data is not in the correct format, write a preprocessing script to convert it into this format.
Note: The task_name field is required in each formatted example.
response_datasets/deepscaler.py has an example:
def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
    return {
        "messages": [
            {"role": "user", "content": data["problem"]},
            {"role": "assistant", "content": data["answer"]},
        ],
        "task_name": self.task_name,
    }
By default, NeMo RL has some built-in supported datasets (e.g., OpenAssistant, OpenMathInstruct-2, Squad, etc.). You can see the full list here. All of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
We provide a ResponseDataset class that loads JSONL-formatted response datasets from a local path or from Hugging Face. Use input_key and output_key to specify which fields in your data hold the question and the answer, respectively. Here's an example configuration:
data:
  # other data settings, see `examples/configs/grpo_math_1B.yaml` for more details
  ...
  # dataset settings
  train:
    # this dataset will override input_key and use the default values for other vars
    data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace)
    input_key: question
    subset: null # used for HuggingFace datasets
    split: train # used for HuggingFace datasets
    split_validation_size: 0.05 # use 5% of the training data as validation data
    seed: 42 # seed for train/validation split when split_validation_size > 0
  validation:
    # this dataset will use the default values for other vars except data_path
    data_path: /path/to/local/val_dataset.jsonl
  default:
    # will use below vars as default values if dataset doesn't specify it
    dataset_name: ResponseDataset
    input_key: input
    output_key: output
    prompt_file: null
    system_prompt_file: null
    processor: "math_hf_data_processor"
    env_name: "math"
Your JSONL files should contain one JSON object per line with the following structure:
{
  "input": "Hello", // <input_key>: <input_content>
  "output": "Hi there!" // <output_key>: <output_content>
}
We support using multiple datasets for train and validation. You can refer to examples/configs/grpo_multiple_datasets.yaml for a full configuration example. Here's an example configuration:
data:
  _override_: true # override the data config instead of merging with it
  # other data settings, see `examples/configs/grpo_math_1B.yaml` for more details
  ...
  # dataset settings
  train:
    # train dataset 1
    - dataset_name: OpenMathInstruct-2
      split_validation_size: 0.05 # use 5% of the training data as validation data
      seed: 42 # seed for train/validation split when split_validation_size > 0
    # train dataset 2
    - dataset_name: DeepScaler
  validation:
    # validation dataset 1
    - dataset_name: AIME2024
      repeat: 16
    # validation dataset 2
    - dataset_name: DAPOMathAIME2024
  # default settings for all datasets
  default:
    ...
We support using a single dataset for both train and validation by using split_validation_size to set the validation ratio.
This feature currently supports OpenAssistant, OpenMathInstruct-2, ResponseDataset, and Tulu3SftMixtureDataset.
To support it in a custom dataset or in another built-in dataset, add the following code to the dataset class, as ResponseDataset does:
# `self.val_dataset` is used (not None) only when current dataset is used for both training and validation
self.val_dataset = None
self.split_train_validation(split_validation_size, seed)
We define a DatumSpec that holds all relevant information for each training example:
class DatumSpec(TypedDict):
    message_log: LLMMessageLogType
    length: int # total (concatenated) length of the message tensors
    extra_env_info: dict[str, Any] # anything your environment requires goes here, for example the 'answer' of a math problem
    loss_multiplier: float # multiplier for the loss for this datum. 0 to mask out (say the sample is invalid)
    idx: int
    task_name: Optional[str] = "default"
    __extra__: Any # This allows additional fields of any type
We refer to each distinct environment your model aims to optimize against as a "task." For example, you might define tasks like "math" or "code."
For each task, you should provide a data processor that reads from your dataset and returns a DatumSpec.
def my_data_processor(
    datum_dict: dict[str, Any], # loaded directly from your dataset (that is, a single line of JSONL data)
    task_data_spec: TaskDataSpec,
    tokenizer,
    max_seq_length: int,
    idx: int,
) -> DatumSpec:
    ...  # build and return a DatumSpec for this example
We have an example of this as math_data_processor in processors.py.
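To make the processor contract more concrete, here is a minimal, self-contained sketch. It does not use NeMo RL's real types: SketchDatumSpec, whitespace_tokenize, the "token_ids" key, and the "ground_truth" key are hypothetical stand-ins chosen for illustration, and the task_data_spec argument is omitted for brevity. The only ideas taken from the text above are the DatumSpec fields and the use of loss_multiplier = 0 to mask out an invalid sample.

```python
from typing import Any, Optional, TypedDict


class SketchDatumSpec(TypedDict):
    """Local stand-in that mirrors the DatumSpec fields listed above."""

    message_log: list[dict[str, Any]]
    length: int
    extra_env_info: dict[str, Any]
    loss_multiplier: float
    idx: int
    task_name: Optional[str]


def whitespace_tokenize(text: str) -> list[str]:
    """Toy tokenizer (assumption): one token per whitespace-separated word."""
    return text.split()


def sketch_math_processor(
    datum_dict: dict[str, Any],
    tokenizer,
    max_seq_length: int,
    idx: int,
) -> SketchDatumSpec:
    """Turn one raw dataset row into a DatumSpec-like dictionary."""
    tokens = tokenizer(datum_dict["problem"])
    message_log = [
        {"role": "user", "content": datum_dict["problem"], "token_ids": tokens}
    ]
    length = len(tokens)
    # Mask out prompts that do not fit rather than silently truncating them.
    loss_multiplier = 1.0 if length <= max_seq_length else 0.0
    return SketchDatumSpec(
        message_log=message_log,
        length=length,
        # The environment needs the reference answer to compute a reward.
        extra_env_info={"ground_truth": datum_dict["answer"]},
        loss_multiplier=loss_multiplier,
        idx=idx,
        task_name=datum_dict.get("task_name", "default"),
    )


row = {"problem": "What is 2 + 2 ?", "answer": "4", "task_name": "math"}
datum = sketch_math_processor(row, whitespace_tokenize, max_seq_length=16, idx=0)
too_long = sketch_math_processor(row, whitespace_tokenize, max_seq_length=3, idx=1)
```

In this sketch the second call exceeds the length budget, so its loss_multiplier is 0 and the sample would contribute nothing to the loss.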
By default, NeMo RL uses a single dataloader that aggregates data from multiple datasets. For scenarios requiring fine-grained control over the number of prompts loaded from each dataset, NeMo RL provides support for multiple dataloaders.
The following example demonstrates how to configure multiple dataloaders:
uv run examples/run_grpo.py \
    --config examples/configs/grpo_multiple_datasets.yaml \
    grpo.num_prompts_per_step=32 \
    data.use_multiple_dataloader=true \
    data.num_prompts_per_dataloader=16 \
    data.custom_dataloader=examples.custom_dataloader.custom_dataloader.example_custom_dataloader
For example, consider using example_custom_dataloader, which samples data from each dataloader sequentially.
Given two datasets:
- Dataset 1: [a, b, c, d]
- Dataset 2: [1, 2, 3, 4, 5, 6, 7, 8]
With data.use_multiple_dataloader=false and grpo.num_prompts_per_step=4:
Batch 1: [a, b, c, d]
Batch 2: [1, 2, 3, 4]
Batch 3: [5, 6, 7, 8]
With data.use_multiple_dataloader=true, grpo.num_prompts_per_step=4, and data.num_prompts_per_dataloader=2:
Batch 1: [a, b, 1, 2]
Batch 2: [c, d, 3, 4]
Batch 3: [a, b, 5, 6]
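The batch layouts above can be reproduced with a small simulation. This is only a sketch in plain Python, not the NeMo RL implementation: the function names are hypothetical, no shuffling is applied, and each step is assumed to make one pass over every dataset, taking a fixed number of prompts from each and restarting a dataset once it is exhausted.

```python
from typing import Iterator


def single_dataloader_batches(
    datasets: list[list], prompts_per_step: int, num_batches: int
) -> list[list]:
    """Concatenate all datasets and read fixed-size batches in order."""
    merged = [item for dataset in datasets for item in dataset]
    return [
        merged[i * prompts_per_step : (i + 1) * prompts_per_step]
        for i in range(num_batches)
    ]


def take_with_reset(
    dataset: list, iterator: Iterator, count: int
) -> tuple[list, Iterator]:
    """Pull `count` items, restarting the iterator whenever it runs out."""
    if not dataset:
        raise ValueError("dataset must not be empty")
    items = []
    while len(items) < count:
        try:
            items.append(next(iterator))
        except StopIteration:
            iterator = iter(dataset)  # reset, so iteration never ends
    return items, iterator


def multiple_dataloader_batches(
    datasets: list[list], prompts_per_dataloader: int, num_batches: int
) -> list[list]:
    """Take `prompts_per_dataloader` items from each dataset for every batch."""
    iterators = [iter(dataset) for dataset in datasets]
    batches = []
    for _ in range(num_batches):
        batch = []
        for i, dataset in enumerate(datasets):
            items, iterators[i] = take_with_reset(
                dataset, iterators[i], prompts_per_dataloader
            )
            batch.extend(items)
        batches.append(batch)
    return batches


dataset_1 = ["a", "b", "c", "d"]
dataset_2 = [1, 2, 3, 4, 5, 6, 7, 8]

single = single_dataloader_batches([dataset_1, dataset_2], 4, 3)
multiple = multiple_dataloader_batches([dataset_1, dataset_2], 2, 3)
```

With two datasets and two prompts per dataloader, every batch holds four prompts, and the third batch starts over at the beginning of the smaller dataset.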
Custom Dataloader
The file examples/custom_dataloader/custom_dataloader.py provides a reference implementation that samples data.num_prompts_per_dataloader entries from each dataloader.
When a single dataloader is exhausted, the data iterator must be reset in the custom dataloader function (as demonstrated in examples/custom_dataloader/custom_dataloader.py).
This design ensures that the MultipleDataloaderWrapper operates as an infinite iterator, where __next__() will not raise StopIteration and __len__() is not supported.
Additionally, custom dataloaders can access recorded metrics from the training loop. Use wrapped_dataloader.set_records() in nemo_rl/algorithms/grpo.py to store relevant information, which can then be retrieved in your custom dataloader implementation:
# In nemo_rl/algorithms/grpo.py
wrapped_dataloader.set_records({"reward": ...})
# In custom_dataloader.py
def example_custom_dataloader(
    data_iterators: dict[str, Iterator],
    dataloaders: dict[str, StatefulDataLoader],
    **kwargs,
) -> tuple[BatchedDataDict, dict[str, Iterator]]:
    ...
    reward = kwargs["reward"]
    ...
num_prompts_per_dataloader
This parameter specifies the number of prompts drawn from each dataloader per iteration. Ensure that grpo.num_prompts_per_step is a multiple of data.num_prompts_per_dataloader to guarantee that exactly grpo.num_prompts_per_step prompts are available for each training step. For example, grpo.num_prompts_per_step=32 with data.num_prompts_per_dataloader=16 satisfies this requirement.
- task_name (unique task identifier):
  - Determines which processor, env, prompts, and dataset to use for this task.
  - Currently, we support a single dataset and a single environment. Therefore, task_name equals the dataset_name in the config (i.e., config.data.dataset_name).
- task_spec (TaskDataSpec):
  - Specifies per-task system prompt and prompt.
- task_data_processors:
  - Dict mapping: task_name -> (task_spec, processor_fn).
- task_to_env:
  - Dict mapping: task_name -> task_env.
Example (simplified):
task_data_processors = {data.task_name: (data.task_spec, data.processor)}
task_to_env = {data.task_name: env}
GRPO expects datasets to have the following form:
{"task_name": "math", /* actual data */}
Then, you can set the data up as follows:
# 1) Setup environments from data config
env_name_list = extract_necessary_env_names(data_config)
envs = {
    env_name: create_env(env_name=env_name, env_config=env_configs[env_name])
    for env_name in env_name_list
}
# 2) Load dataset using the helper (built-ins or local/HF datasets)
data = load_response_dataset(data_config["train"])
# 3) Build task mapping
task_data_processors = {data.task_name: (data.task_spec, data.processor)}
task_to_env = {data.task_name: envs[data_config["train"]["env_name"]]}
# 4) Construct processed dataset
dataset = AllTaskProcessedDataset(
    data.dataset,
    tokenizer,
    None,
    task_data_processors,
    max_seq_length=data_config["max_input_seq_length"],
)
# 5) Do the same thing for validation dataset if it exists
if "validation" in data_config and data_config["validation"] is not None:
    val_data = load_response_dataset(data_config["validation"])
    val_task_data_processors = {val_data.task_name: (val_data.task_spec, val_data.processor)}
    val_task_to_env = {val_data.task_name: envs[data_config["validation"]["env_name"]]}
    val_dataset = AllTaskProcessedDataset(
        val_data.dataset,
        tokenizer,
        None,
        val_task_data_processors,
        max_seq_length=data_config["max_input_seq_length"],
    )
Ensure you provide a mapping of tasks to their processors so the dataset knows which processor to use when handling samples.
GRPO supports various types of environments for different tasks, including Math, Code, and Reward Model environments. Each environment provides a standardized interface for reward computation and evaluation, enabling consistent training across diverse domains.
For more information about environments, see the Environments Guide.
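As a rough illustration of the contract described earlier (a state/action history goes in, an updated state and per-sample rewards come out), the sketch below implements a toy math environment in plain Python. It is not NeMo RL's environment interface and does not run as a Ray actor: the class name, the step signature, and the "ground_truth" metadata key are all assumptions made for this example.

```python
from typing import Any


class ToyMathEnvironment:
    """Hypothetical stand-in for an environment; exact-match reward only."""

    def step(
        self,
        message_logs: list[list[dict[str, str]]],
        metadata: list[dict[str, Any]],
    ) -> tuple[list[dict[str, str]], list[float]]:
        """Score the last message of each conversation against its answer."""
        observations: list[dict[str, str]] = []
        rewards: list[float] = []
        for message_log, info in zip(message_logs, metadata):
            last_reply = message_log[-1]["content"].strip()
            correct = last_reply == str(info["ground_truth"]).strip()
            rewards.append(1.0 if correct else 0.0)
            # The observation is the environment's contribution to the state.
            observations.append(
                {"role": "environment", "content": "correct" if correct else "incorrect"}
            )
        return observations, rewards


env = ToyMathEnvironment()
logs = [
    [{"role": "user", "content": "2 + 2?"}, {"role": "assistant", "content": "4"}],
    [{"role": "user", "content": "3 * 3?"}, {"role": "assistant", "content": "6"}],
]
observations, rewards = env.step(logs, [{"ground_truth": "4"}, {"ground_truth": "9"}])
```

A real environment would typically parse and normalize the answer before comparing it; exact string matching is used here only to keep the sketch short.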
- env:
  - The environment actor for reward/evaluation, constructed using create_env(env_name=..., env_config=...).
  - The environment to use is declared under the data section of the config (e.g., data.env_name states which env the dataset uses).
- task_to_env:
  - Dict mapping: task_name -> env. In the current single-task setup this typically points all tasks to the same env, but this structure enables different envs per task in future multi-task scenarios.
Example (simplified):
env_name_list = extract_necessary_env_names(data_config)
envs = {
    env_name: create_env(env_name=env_name, env_config=env_configs[env_name])
    for env_name in env_name_list
}
task_to_env[task_name] = envs[data_config["train"]["env_name"]]
val_task_to_env = task_to_env # validation usually mirrors training mapping
We define a PolicyInterface (nemo_rl.models.policy.interfaces.PolicyInterface) that contains everything you need to train a Policy model.
This Policy object holds a RayWorkerGroup of SPMD processes (one process per GPU) that run HF/MCore. The object coordinates all of them, so from your side it behaves like a single GPU.
Generation currently goes through vLLM via the VllmGeneration class.
The grpo_train function contains the core GRPO training loop.
RL generations typically produce highly variable sequence lengths, which result in a significant amount of padding if approached naively. We address this with Sequence Packing and Dynamic Batching, which are techniques to reduce the amount of padding required. You can read more about these in the design doc.
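We use the ClippedPGLossFn to calculate the loss for GRPO. Formally,
$$
L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)} A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_{\text{ref}})
$$
where:
- $\pi_\theta$ is the policy model we are currently optimizing
- $\pi_{\theta_{\text{old}}}$ is the previous policy model (from the beginning of this step)
- $A_t$ is the advantage estimate
- $\varepsilon$ is a clipping hyperparameter
- $\beta$ is the KL penalty coefficient
- $\pi_{\text{ref}}$ is the reference policy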
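The clipped term can be sketched per token in plain Python. This is an illustration of the standard clipped policy-gradient surrogate built from the symbols defined above, not the ClippedPGLossFn implementation: the function name and the default epsilon of 0.2 are assumptions, the KL penalty is left out, and the value is the quantity to be maximized (a training loss would negate it).

```python
import math


def clipped_pg_objective(
    logp_new: list[float],
    logp_old: list[float],
    advantages: list[float],
    epsilon: float = 0.2,
) -> float:
    """Mean over tokens of min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    terms = []
    for lp_new, lp_old, advantage in zip(logp_new, logp_old, advantages):
        # Probability ratio between the current and the old policy.
        ratio = math.exp(lp_new - lp_old)
        clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
        # Taking the minimum keeps the more pessimistic of the two estimates.
        terms.append(min(ratio * advantage, clipped_ratio * advantage))
    return sum(terms) / len(terms)


# Ratios of 1.5, 0.5 and 1.0 against an old policy with log-probability 0.
logp_new = [math.log(1.5), math.log(0.5), 0.0]
logp_old = [0.0, 0.0, 0.0]
advantages = [1.0, -1.0, 2.0]

objective = clipped_pg_objective(logp_new, logp_old, advantages)
```

The first token is clipped at 1.2, the second at 0.8 (giving -0.8), and the third is unaffected, so the mean is (1.2 - 0.8 + 2.0) / 3 = 0.8.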
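It also supports "Dual-Clipping" from Ye et al. (2019), which imposes an additional upper bound on the probability ratio when advantages are negative. This prevents excessive policy updates. For tokens with $A_t < 0$, the loss becomes:
$$
L(\theta) = E_t \Big[ \max \Big( \min \big( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1 - \varepsilon, 1 + \varepsilon) A_t \big), c A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_{\text{ref}})
$$
where:
- $c$ is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is usually set to 3 empirically.
- $r_t(\theta)$ is the ratio $\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$ that measures how much the policy has changed.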
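The effect of the dual-clip bound is easiest to see on single tokens. The sketch below is a plain-Python illustration under the description above (an extra bound c applied only when the advantage is negative); the function name and default values are assumptions, and it is not the library's implementation.

```python
def dual_clip_term(
    ratio: float, advantage: float, epsilon: float = 0.2, c: float = 3.0
) -> float:
    """Clipped surrogate for one token, with dual clipping for A < 0."""
    if c <= 1.0:
        raise ValueError("the dual-clip parameter c must be greater than 1")
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    standard = min(ratio * advantage, clipped_ratio * advantage)
    if advantage < 0:
        # Without this bound the term is unbounded below as the ratio grows.
        return max(standard, c * advantage)
    return standard


large_ratio_negative = dual_clip_term(ratio=10.0, advantage=-1.0)
small_ratio_negative = dual_clip_term(ratio=2.0, advantage=-1.0)
large_ratio_positive = dual_clip_term(ratio=10.0, advantage=1.0)
```

With a ratio of 10 and an advantage of -1, the standard term would be -10, while the dual-clipped term stops at c * A = -3. A ratio of 2 stays at -2 because it is already within the bound, and positive advantages are handled by the usual clip alone.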
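This feature is controlled by the parameter use_on_policy_kl_approximation. It enables an estimator of the KL divergence based on Schulman (2020) that is both unbiased and guaranteed to be non-negative:
$$
D_{\text{KL}} (\pi_\theta \| \pi_{\text{ref}}) \approx E_{x \sim \pi_\theta} \Big[ \frac{\pi_{\text{ref}}(x)}{\pi_\theta(x)} - \log \frac{\pi_{\text{ref}}(x)}{\pi_\theta(x)} - 1 \Big]
$$
Note that the loss function above samples from $\pi_{\theta_{\text{old}}}$ rather than $\pi_\theta$, so the estimate is off-policy when it is evaluated directly on those samples. To keep the estimate on-policy while still sampling from $\pi_{\theta_{\text{old}}}$, each sample can be weighted by the importance ratio:
$$
D_{\text{KL}} (\pi_\theta \| \pi_{\text{ref}}) \approx E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)} \Big( \frac{\pi_{\text{ref}}(x)}{\pi_\theta(x)} - \log \frac{\pi_{\text{ref}}(x)}{\pi_\theta(x)} - 1 \Big) \Big]
$$
To enable the on-policy KL approximation, set the config use_on_policy_kl_approximation=True in the ClippedPGLossConfig. By default, we set this config to False to align with standard GRPO.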
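This feature is controlled by the parameter use_importance_sampling_correction. It applies importance sampling to adjust for discrepancies between the behavior policy and the target policy, improving the accuracy of off-policy estimates. The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is evaluated in both the inference framework and the training framework. To distinguish the two, we write $\pi_{\text{inference}}$ for the policy as computed by the inference framework and $\pi_{\text{training}}$ for the policy as computed by the training framework. Their token probabilities can differ (for example, because of numerics or precision), which makes the samples slightly off-policy.
Let $f_\theta(x) = \min \Big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)} A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ denote the first term of the loss function. Then
$$
E_{x \sim \pi_{\text{training}}} \big[ f_\theta(x) \big] = \sum_x \pi_{\text{training}}(x) f_\theta(x) = \sum_x \pi_{\text{inference}}(x) \frac{\pi_{\text{training}}(x)}{\pi_{\text{inference}}(x)} f_\theta(x) = E_{x \sim \pi_{\text{inference}}} \Big[ \frac{\pi_{\text{training}}(x)}{\pi_{\text{inference}}(x)} f_\theta(x) \Big]
$$
By multiplying the first term of the loss function by the importance weights $\frac{\pi_{\text{training}}(x)}{\pi_{\text{inference}}(x)}$, we correct for the mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$.
To enable the importance sampling correction, set the config use_importance_sampling_correction=True in the ClippedPGLossConfig. By default, we set this config to False to align with standard GRPO.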
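The idea behind the correction can be checked exactly on a tiny discrete distribution. The sketch below is a generic importance-sampling illustration, not NeMo RL code: the distributions, the per-outcome values, and the helper name are made up for the example. It shows that reweighting by the ratio of target to behavior probabilities recovers the expectation under the target policy, while the uncorrected average does not.

```python
from typing import Callable


def expectation(probs: dict[str, float], f: Callable[[str], float]) -> float:
    """Exact expectation of f under a discrete distribution."""
    return sum(p * f(x) for x, p in probs.items())


# Hypothetical distributions over three outcomes.
target = {"a": 0.5, "b": 0.3, "c": 0.2}    # policy we want expectations under
behavior = {"a": 0.4, "b": 0.4, "c": 0.2}  # policy that produced the samples
values = {"a": 1.0, "b": -2.0, "c": 4.0}   # stand-in for the per-sample loss term

# Expectation under the target policy, computed directly.
direct = expectation(target, lambda x: values[x])

# Same quantity, computed under the behavior policy with importance weights.
reweighted = expectation(behavior, lambda x: target[x] / behavior[x] * values[x])

# Ignoring the mismatch gives a biased answer.
uncorrected = expectation(behavior, lambda x: values[x])
```

Here the direct and reweighted expectations are both 0.7, whereas the uncorrected value is 0.4. In practice the expectation is estimated from samples, so the weights also add variance; that trade-off is not modeled in this sketch.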
This feature is controlled by the parameter overlong_filtering. It filters out sequences that exceed a predefined maximum length, helping maintain computational efficiency and model stability. When overlong_filtering=True, samples that reach max_total_sequence_length without producing an end-of-text token are excluded from loss computation. This reduces noise from penalizing generations that may be high-quality but exceed the sequence length limit.
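The implementation modifies the loss calculation as follows:
For each sample $i$ in the batch, let $\text{truncated}_i = 1$ if the sample reached max_total_sequence_length without producing an end-of-text token, and $\text{truncated}_i = 0$ otherwise.
The sample mask becomes (let $m_i$ denote the sample mask and $\ell_i$ denote the loss multiplier):
$$
m_i = \ell_i \cdot (1 - \text{truncated}_i)
$$
This results in the effective loss:
$$
L_{\text{eff}} = \sum_i m_i \cdot L_i
$$
where $L_i$ is the per-sample loss.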
To configure:
grpo:
  overlong_filtering: false # default
Set overlong_filtering to true when training on tasks where truncation at the maximum sequence length is expected, such as long-form reasoning or mathematical proofs.
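The masking rule can be sketched in a few lines of plain Python. This is an illustration of the behavior described above (truncated samples without an end-of-text token get a zero mask), not the library's implementation; the function names and argument layout are assumptions, and the masked loss is written as an unnormalized sum for simplicity.

```python
def overlong_sample_mask(
    lengths: list[int],
    ended_with_eos: list[bool],
    loss_multipliers: list[float],
    max_total_sequence_length: int,
) -> list[float]:
    """Zero the loss multiplier of samples that hit the length limit without EOS."""
    masks = []
    for length, has_eos, multiplier in zip(lengths, ended_with_eos, loss_multipliers):
        truncated = length >= max_total_sequence_length and not has_eos
        masks.append(0.0 if truncated else multiplier)
    return masks


def masked_loss(per_sample_losses: list[float], masks: list[float]) -> float:
    """Sum of per-sample losses weighted by their masks."""
    return sum(mask * loss for mask, loss in zip(masks, per_sample_losses))


masks = overlong_sample_mask(
    lengths=[10, 16, 16],
    ended_with_eos=[True, False, True],
    loss_multipliers=[1.0, 1.0, 1.0],
    max_total_sequence_length=16,
)
effective = masked_loss([0.5, 2.0, 1.5], masks)
```

Only the second sample is dropped: it reached the limit of 16 tokens without an end-of-text token. The third sample also has 16 tokens but finished properly, so it still contributes to the loss.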
This feature is controlled by the parameters wandb_name and tb_name. We track a few metrics during training for scientific experimentation and to validate correctness as the run progresses.
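This metric is reported as token_mult_prob_error. It measures how far, multiplicatively, the token probabilities computed by the training framework are from those reported by the inference framework. This is equal to the 'Logprob consistency metric' defined in Adding New Models:
$$
\text{token-mult-prob-error} = \frac{1}{n} \sum_{i=1}^{n} \exp \big( \left| \text{logprobs-train-fwk}_i - \text{logprobs-inference-fwk}_i \right| \big)
$$
Intuitively, this measures the average multiplicative probability error for sampled tokens, where samples are drawn as $x \sim \pi_{\text{inference}}$, that is, from the inference framework.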
The KL divergence error is reported through the following metrics:
- gen_kl_error: $D_{\text{KL}}(P_{gen} || P_{policy})$, which treats the generation distribution as ground truth
- policy_kl_error: $D_{\text{KL}}(P_{policy} || P_{gen})$, which treats the policy (training) distribution as ground truth
- js_divergence_error (Jensen–Shannon divergence): $(D_{\text{KL}}(P_{policy} || P_{m}) + D_{\text{KL}}(P_{gen} || P_{m})) / 2$, where $P_{m} = (P_{policy} + P_{gen}) / 2$, which uses the mean mixture distribution as reference
The paper When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch introduced gen_kl_error (referred to there as vllm-kl) as the key metric for measuring the mismatch between the policy and generation distributions. Empirically, the mismatch is approximately 1e-3, and the divergence is larger for tokens to which the generation inference engine (like vLLM) assigns low probability.
The three divergence metrics provide complementary perspectives on distribution mismatch. For example:
We observed a case where vLLM assigned a disproportionately high probability to a single rare token, causing significant logprob error spikes (especially in MoE architectures):
# extreme example
1. Position 4559: 'au' (ID: 1786)
   logp_gen (from vLLM): -5.xxx
   logp_policy (from Mcore): -15.xxx
Assuming other tokens have near-zero divergence, this single token's metrics with kl_type=k3 are:
- gen_kl_error: exp(-15 + 5) - (-15 + 5) - 1 ≈ 9 (moderate mismatch)
- policy_kl_error: exp(-5 + 15) - (-5 + 15) - 1 ≈ 22,015 (severe mismatch dominating the metric)
- js_divergence_error: ≈ 9, close to gen_kl_error since the mixture distribution (~-5.69) is dominated by the higher-probability value (logp_gen in this example)
Ideally, all KL divergence metrics should be close to 0, with values below 1e-3 considered acceptable. Investigate any metric that shows spikes above this threshold.
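The two single-token values from the example above can be reproduced with the k3 expression used there, exp(d) - d - 1, where d is the log-probability difference. The sketch below is a plain-Python check of that arithmetic only; the function name is an assumption, the log-probabilities are rounded to -5 and -15, and the Jensen–Shannon value is not recomputed.

```python
import math


def k3_estimate(logp_sampled: float, logp_other: float) -> float:
    """Per-token k3 term: exp(d) - d - 1 with d = logp_other - logp_sampled."""
    delta = logp_other - logp_sampled
    return math.exp(delta) - delta - 1.0


logp_gen = -5.0      # log-probability reported by the generation engine
logp_policy = -15.0  # log-probability computed by the training framework

# Generation distribution treated as ground truth.
gen_kl_error = k3_estimate(logp_gen, logp_policy)

# Policy (training) distribution treated as ground truth.
policy_kl_error = k3_estimate(logp_policy, logp_gen)
```

The asymmetry comes from the exponential: a difference of -10 contributes almost nothing through exp(-10), while a difference of +10 contributes exp(10), which is roughly 22,026.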
This metric is reported as sampling_importance_ratio. Not to be confused with the clipped importance ratio in PPO/GRPO, this is the importance ratio between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$, that is, between the token probabilities computed by the training framework (the target policy) and those computed by the inference backend (the behavior policy). When used in the loss, it reweights samples to correct for the distributional shift between the two.
This is simply
$$
\frac{1}{|T|} \sum_{t \in T} \exp \big( \log \pi_{\text{training}}(x_t) - \log \pi_{\text{inference}}(x_t) \big)
$$
Similar to Multiplicative Token Probability Error, this is a measure of how far off your inference backend is from your training framework. However, this metric is meant to find the bias in that error, rather than the variance, as it does not take the absolute value of the error. With some noise, this should hover around 1.
This metric is always calculated, and the per-token version (without the mean) is used in the loss function when Importance Sampling Correction is enabled.
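The difference between the two mismatch metrics (one takes the absolute value of the log-probability error, the other does not) can be seen on a small example. The sketch below assumes both metrics are token-level means of exponentiated log-probability differences between the training framework and the inference backend; the function names and inputs are illustrative, not the library's code.

```python
import math


def token_mult_prob_error(logp_train: list[float], logp_inference: list[float]) -> float:
    """Mean of exp(|logp_train - logp_inference|); always at least 1."""
    diffs = [t - i for t, i in zip(logp_train, logp_inference)]
    return sum(math.exp(abs(d)) for d in diffs) / len(diffs)


def sampling_importance_ratio(logp_train: list[float], logp_inference: list[float]) -> float:
    """Mean of exp(logp_train - logp_inference); signed errors can cancel."""
    diffs = [t - i for t, i in zip(logp_train, logp_inference)]
    return sum(math.exp(d) for d in diffs) / len(diffs)


# Symmetric noise: one token is 0.1 nats higher in training, one 0.1 lower.
logp_train = [-1.0, -2.0]
logp_inference = [-1.1, -1.9]

error = token_mult_prob_error(logp_train, logp_inference)
ratio = sampling_importance_ratio(logp_train, logp_inference)
```

With symmetric noise the error metric is exp(0.1), about 1.105, while the ratio stays close to 1 (cosh(0.1), about 1.005). A ratio that drifts consistently away from 1 would instead point to a systematic bias.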
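This metric is reported as approx_entropy. It estimates the entropy of the policy distribution, which indicates how much the policy is still exploring. We roughly approximate the entropy of the LLM's distribution throughout training by calculating:
$$
E_{x \sim \pi_{\text{inference}}} \Big[ - \frac{\pi_{\text{training}}(x)}{\pi_{\text{inference}}(x)} \log \pi_{\text{training}}(x) \Big]
$$
This expectation is estimated using the rollouts in each global training batch as Monte Carlo samples. The ratio of $\pi_{\text{training}}$ to $\pi_{\text{inference}}$ reweights those samples to correct for the mismatch between the policy that generated them and the policy being trained.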
We use this to track if our models are experiencing entropy collapse too quickly during training (as is quite common). This is a fairly rough Monte Carlo approximation, so we wouldn't recommend using this directly for an entropy bonus or otherwise backpropagating through this. You can take a look at NeMo Aligner's implementation of a full entropy calculation if you're interested (work-in-progress efficient calculation in NeMo RL).
GRPO supports LoRA on the NeMo RL DTensor backend. The LoRA settings live under policy.dtensor_cfg.lora_cfg, and the fields follow the SFT LoRA configuration. For DTensor parameter details, see SFT LoRA: DTensor Configuration Parameters. To enable LoRA, set policy.dtensor_cfg.lora_cfg.enabled=true, then configure target modules, rank, alpha, and dropout as needed.
Our DTensor LoRA path uses a merge-weight approach: during generation, LoRA adapter weights are merged into the base linear weights. This improves performance, with a small training-inference mismatch that we consider acceptable. If you require strict training-inference parity, use the split-weight variant branch, which may trade off some performance. For a comparison between merge-weight and split-weight, see PR 1797: Support lora in dtensor grpo workflow by merging weight.
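To illustrate what merging means, the sketch below compares the two ways of applying a LoRA adapter to a single linear layer, using the standard LoRA formulation W + (alpha / r) * B * A. It is written in plain Python with tiny hand-picked matrices and is not the DTensor implementation; all function names and values are assumptions for the example.

```python
Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora(base: Matrix, lora_b: Matrix, lora_a: Matrix, alpha: float, rank: int) -> Matrix:
    """Merge-weight path: fold the adapter into the base weight once."""
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(base, delta)
    ]


def split_forward(
    base: Matrix, lora_b: Matrix, lora_a: Matrix, alpha: float, rank: int, x: Matrix
) -> Matrix:
    """Split-weight path: keep the base and adapter outputs separate, then add."""
    scale = alpha / rank
    base_out = matmul(base, x)
    adapter_out = matmul(lora_b, matmul(lora_a, x))
    return [
        [b + scale * a for b, a in zip(b_row, a_row)]
        for b_row, a_row in zip(base_out, adapter_out)
    ]


base = [[1.0, 0.0], [0.0, 1.0]]  # W, shape (out=2, in=2)
lora_b = [[1.0], [2.0]]          # B, shape (out=2, r=1)
lora_a = [[0.5, -0.5]]           # A, shape (r=1, in=2)
x = [[1.0], [3.0]]               # input column vector

merged_out = matmul(merge_lora(base, lora_b, lora_a, alpha=2.0, rank=1), x)
split_out = split_forward(base, lora_b, lora_a, alpha=2.0, rank=1, x=x)
```

In exact arithmetic the two paths give the same output, as they do here. With low-precision weights the merged matrix is rounded differently from the separately computed sum, which is one plausible source of the small mismatch mentioned above; that explanation is an assumption, not something this sketch demonstrates.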
We already provide a DTensor-based Nano v3 GRPO LoRA recipe. See grpo-nanov3-30BA3B-2n8g-fsdp2-lora.yaml for an end-to-end example.
Upon completion of the training process, you can refer to our evaluation guide to assess model capabilities.