fix: skip_reference_policy_logprobs_calculation=true crashes training #2174
Open
ShriyaRishab wants to merge 1 commit into NVIDIA-NeMo:main from
Conversation
Fixes NVIDIA-NeMo#1968: Setting `skip_reference_policy_logprobs_calculation=true` with `reference_policy_kl_penalty=0` crashed training in three ways:

- **Bug 1:** the `use_reference_model()` context manager crashed when the reference model was never initialized (`AttributeError` on `reference_state_dict`). Fix: added an early-return guard in `use_reference_model()` for all three worker types (megatron, dtensor v1, dtensor v2) that yields without swapping when the reference model is None/missing.
- **Bug 2:** the async GRPO path unconditionally called `get_reference_policy_logprobs()` without checking the skip flag. Fix: added the same skip guard as the sync path, setting `zeros_like` for `reference_policy_logprobs` when skipping.
- **Bug 3:** the `reference_policy_logprobs` key was missing from `train_data`, causing shape mismatches downstream in loss computation. Fix: both sync and async paths now explicitly set `train_data['reference_policy_logprobs'] = zeros_like(prev_logprobs)` when skipping.

Also added a `_has_reference_model()` helper and a zeros fallback in `base_policy_worker.get_reference_policy_logprobs()` as defense-in-depth.
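The early-return guard described in Bug 1 can be sketched as follows. This is a minimal illustration, not the actual NeMo RL worker code; `PolicyWorkerSketch` and its attributes are hypothetical stand-ins:

```python
from contextlib import contextmanager


class PolicyWorkerSketch:
    """Illustrative stand-in for a policy worker with an optional reference model."""

    def __init__(self, reference_state_dict=None):
        # reference_state_dict stays None when
        # skip_reference_policy_logprobs_calculation=true
        self.reference_state_dict = reference_state_dict
        self.model_state_dict = {"weights": "policy"}

    @contextmanager
    def use_reference_model(self):
        # Early-return guard (the fix): if no reference model exists,
        # yield without swapping instead of raising AttributeError.
        if getattr(self, "reference_state_dict", None) is None:
            yield
            return
        # Normal path: temporarily swap in the reference weights,
        # restoring the policy weights on exit.
        saved = self.model_state_dict
        self.model_state_dict = self.reference_state_dict
        try:
            yield
        finally:
            self.model_state_dict = saved
```

With the guard in place, entering the context manager is a no-op when no reference model was initialized, and the normal swap/restore behavior is unchanged otherwise.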
What does this PR do?
Summary
Setting `skip_reference_policy_logprobs_calculation=true` in the GRPO config crashes because:

- `reference_policy_logprobs` is never assigned to `train_data` when skipped
- the `use_reference_model()` context manager crashes when no reference state dict exists

Fixes #1968
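The missing-assignment fix can be sketched as a small helper; `attach_reference_logprobs` and its arguments are illustrative names, not the actual NeMo RL API:

```python
import torch


def attach_reference_logprobs(train_data, prev_logprobs, skip, ref_logprobs_fn=None):
    """Ensure train_data always carries a 'reference_policy_logprobs' key.

    When the reference calculation is skipped, fill the key with zeros of
    the same shape as prev_logprobs so downstream loss code (which divides
    the batch by this tensor's shape) sees no mismatch.
    """
    if skip:
        train_data["reference_policy_logprobs"] = torch.zeros_like(prev_logprobs)
    else:
        train_data["reference_policy_logprobs"] = ref_logprobs_fn(train_data)
    return train_data
```

With `reference_policy_kl_penalty=0`, an all-zeros tensor leaves the loss unchanged while keeping tensor shapes consistent across the sync and async paths.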
Root Cause
Three code paths needed fixes:
- `grpo.py` sync path: missing `train_data["reference_policy_logprobs"]` assignment
- `grpo.py` async path: same
- `use_reference_model()` tries to swap non-existent state dicts

Fix

- Assign `torch.zeros_like(prev_logprobs)` to `reference_policy_logprobs`
- Added a `_has_reference_model()` base method
- `get_reference_policy_logprobs()`: return zeros if no reference model
- `use_reference_model()` context managers: yield without swapping if no reference state dict

Issues
List issues that this PR closes (syntax):
#1968
Usage
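A minimal sketch of the configuration this PR fixes; the exact key placement in the GRPO config file is an assumption:

```yaml
# Hypothetical GRPO config fragment
grpo:
  reference_policy_kl_penalty: 0.0
  skip_reference_policy_logprobs_calculation: true
```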
Before your PR is "Ready for review"
Pre checks:
Additional Information