fix: skip loading reference model when KL penalty is zero #2178
Conversation
When reference_policy_kl_penalty is 0, the reference model is unused during GRPO training. Pass init_reference_model=False to avoid allocating memory for the reference model weights.

Closes #1957

Co-Authored-By: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
```python
policy_config["megatron_cfg"]["train_iters"] = total_train_iters

# Define initialization functions that will be used in all paths
init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
```
Shall we set skip_reference_policy_logprobs_calculation to True in this situation? Otherwise I guess we will get an error when calling get_reference_policy_logprobs.
Also, I think it's better to add a functional test (or modify an existing one) for reference_policy_kl_penalty == 0.
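A minimal sketch of what such a test could assert, assuming the flag-derivation logic discussed in this thread is factored into a helper. The function name derive_grpo_flags is hypothetical; the config keys are the ones used in this PR:

```python
# Hypothetical helper mirroring the proposed derivation logic;
# the function name is illustrative, the config keys come from the PR.
def derive_grpo_flags(master_config):
    init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
    if not init_reference_model:
        # Auto-skip reference logprob calculation when the reference
        # model is never loaded.
        master_config["grpo"]["skip_reference_policy_logprobs_calculation"] = True
    return init_reference_model


# A KL penalty of 0 should disable both the reference model and its logprobs.
cfg = {"loss_fn": {"reference_policy_kl_penalty": 0}, "grpo": {}}
assert derive_grpo_flags(cfg) is False
assert cfg["grpo"]["skip_reference_policy_logprobs_calculation"] is True

# A positive penalty should load the reference model and leave the flag alone.
cfg = {"loss_fn": {"reference_policy_kl_penalty": 0.01}, "grpo": {}}
assert derive_grpo_flags(cfg) is True
assert "skip_reference_policy_logprobs_calculation" not in cfg["grpo"]
```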
```python
policy_config["megatron_cfg"]["train_iters"] = total_train_iters

# Define initialization functions that will be used in all paths
init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
```
BUG: Setting init_reference_model=False here prevents reference model weights from being loaded, but the sync training loop (line 1754) still calls policy.get_reference_policy_logprobs() unless grpo.skip_reference_policy_logprobs_calculation is explicitly True.
When reference_policy_kl_penalty=0 and the skip flag is unset, use_reference_model() accesses self.reference_model_state_dict which was never initialized → AttributeError.
Multiple existing configs are affected:

- examples/nemo_gym/grpo_nanov3.yaml
- examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml
- examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml
- examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml

All have reference_policy_kl_penalty: 0 without setting skip_reference_policy_logprobs_calculation: true.
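Until the skip flag is derived automatically, those configs would need it set explicitly. A sketch of the relevant fragment (key names taken from this PR; surrounding structure is assumed):

```yaml
loss_fn:
  reference_policy_kl_penalty: 0
grpo:
  # Required while init_reference_model is skipped for a zero penalty
  skip_reference_policy_logprobs_calculation: true
```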
Suggested fix: auto-derive the skip flag:

```python
init_reference_model = master_config["loss_fn"]["reference_policy_kl_penalty"] > 0
# Auto-skip reference logprob calculation when reference model is not loaded
if not init_reference_model:
    master_config["grpo"]["skip_reference_policy_logprobs_calculation"] = True
```
Bug: async GRPO path missing reference logprob skip guard

The async GRPO path needs the same guard as the sync path (line 1754).

Re: @yuki-97's comment

Great catch; both points are valid.
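A sketch of the call-site guard both training loops would need. The wrapper name maybe_get_reference_logprobs and the fake policy are illustrative stand-ins; the config key and the get_reference_policy_logprobs call come from this PR:

```python
# Illustrative guard for both the sync and async GRPO paths; the wrapper
# name is hypothetical, the config key comes from the PR discussion.
def maybe_get_reference_logprobs(policy, batch, master_config):
    if master_config["grpo"].get("skip_reference_policy_logprobs_calculation", False):
        # Reference model was never loaded; there is nothing to compute.
        return None
    return policy.get_reference_policy_logprobs(batch)


class _FakePolicy:
    """Stand-in for the real policy object, for demonstration only."""

    def get_reference_policy_logprobs(self, batch):
        return [0.0 for _ in batch]


# With the skip flag set, the reference model is never touched.
cfg = {"grpo": {"skip_reference_policy_logprobs_calculation": True}}
assert maybe_get_reference_logprobs(_FakePolicy(), [1, 2], cfg) is None

# With the flag unset, logprobs are computed as before.
cfg = {"grpo": {}}
assert maybe_get_reference_logprobs(_FakePolicy(), [1, 2], cfg) == [0.0, 0.0]
```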
Generated by Claude Code
What does this PR do ?

When reference_policy_kl_penalty is 0, the reference model is unused during GRPO training. Pass init_reference_model=False to avoid allocating memory for the reference model weights.

Issues

Closes #1957

Usage

```python
# Add a code snippet demonstrating how to use this
```

Before your PR is "Ready for review"

Pre checks:

Additional Information