feat: Megatron LoRA GRPO w/ Weight Merging (#1889)
Conversation
📝 Walkthrough

This PR updates Megatron-LM and Megatron-Bridge submodule references to new repositories and commits, updates dependencies in setup.py files to support newer versions, introduces PEFT/LoRA configuration blocks for GRPO, adds a state dict remapping utility method to MegatronPolicyWorker, and introduces new functional test scripts for Megatron-LoRA GRPO training experiments.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/models/policy/workers/megatron_policy_worker.py (1)
1565-1588: ⚠️ Potential issue | 🟠 Major: Manual state_dict swap will KeyError on LoRA-only keys.

When `use_peft=True`, `self.model.state_dict()` contains LoRA adapter keys (e.g., `lora_A`, `lora_B`) that don't exist in `self.reference_state_dict`. Line 1569 (`self.reference_state_dict[k]`) and line 1588 (`model_state_dict[k]`) will raise `KeyError` for these keys.

The commented-out `load_state_dict` with `strict=True` (lines 1565, 1585) had the same problem, which is presumably why it was replaced. But the manual approach needs to handle missing keys too.

Proposed fix:
```diff
 # Swap reference model state_dict to self.model
 for k, v in self.model.state_dict().items():
     if isinstance(v, torch.Tensor):
-        v.copy_(self.reference_state_dict[k])
+        if k in self.reference_state_dict:
+            v.copy_(self.reference_state_dict[k])

-for k, v in self.model.state_dict().items():
-    if isinstance(v, torch.Tensor):
-        v.copy_(model_state_dict[k])
+for k, v in self.model.state_dict().items():
+    if isinstance(v, torch.Tensor) and k in model_state_dict:
+        v.copy_(model_state_dict[k])
```
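The guarded swap can also be sketched as a standalone helper. This is a plain-dict illustration (not the worker's actual API) of the key idea: copy matching entries and skip keys that exist only on one side, instead of raising `KeyError`:

```python
def swap_state_dict_(model_state, source_state):
    """Copy matching entries from source_state into model_state in place.

    Keys that exist only in model_state (for example LoRA adapter keys such
    as '...lora_A.weight' when the reference model was built without PEFT)
    are skipped instead of raising KeyError. Plain dict values stand in for
    the in-place tensor.copy_() the real worker performs.
    """
    skipped = []
    for key in model_state:
        if key in source_state:
            model_state[key] = source_state[key]
        else:
            skipped.append(key)
    return skipped
```

Returning the skipped keys lets callers log exactly which adapter weights were left untouched during the swap.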
🤖 Fix all issues with AI agents
In `@.gitmodules`:
- Around lines 3-4: Update the submodule declaration that currently sets `url = https://github.com/yaoyu-33/Megatron-LM.git` (with `branch = main`) so it points to the official upstream `https://github.com/NVIDIA/Megatron-LM.git`. If the fork is intentionally required, replace the URL only after adding a short justification in the repo docs (e.g., SECURITY.md or README) explaining why the fork is needed and noting any maintained diffs, and include a maintainer sign-off in the justification so reviewers can accept the deviation.
In `@3rdparty/Megatron-LM-workspace/Megatron-LM`:
- Line 1: The submodule 3rdparty/Megatron-LM-workspace/Megatron-LM points at a
non-existent commit (11dcbaca317133cc5c77c8bc4f54ed71d3b5d656); update the
submodule to a valid commit/branch on the upstream Megatron-LM remote by
entering the submodule (cd 3rdparty/Megatron-LM-workspace/Megatron-LM), running
git fetch origin, checking out a known-good commit or branch (e.g., origin/main
or a specific existing SHA), then git add the submodule change in the
superproject, commit the update, and push the branch so the PR references a
valid submodule commit.
In `@examples/configs/grpo_math_1B_megatron_lora.yaml`:
- Line 114: Replace the YAML value that sets lora_dtype so it yields a true null
rather than the string "None": change the mapping key/value where lora_dtype is
defined (currently `lora_dtype: None`) to use YAML null (`lora_dtype: null` or
`lora_dtype: ~`) so that downstream code constructing LoRA (e.g.,
LoRA(lora_dtype=...)) receives a null/None value instead of the string "None".
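As a defensive complement to fixing the YAML itself, a loader-side normalizer can be sketched (a hypothetical helper, not part of the repo) that maps the literal string "None" to a true null before it reaches the LoRA constructor:

```python
def normalize_yaml_none(value):
    """Treat the common YAML mistake `lora_dtype: None` (which parses as
    the string "None") as Python None; pass every other value through.

    The extra "null"/"~" strings are defensive, covering quoted nulls.
    """
    if isinstance(value, str) and value.strip() in ("None", "none", "null", "~"):
        return None
    return value
```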
In `@examples/configs/grpo_math_1B_megatron.yaml`:
- Around line 100-111: The base Megatron config currently enables LoRA by
default and sets lora_dtype to the literal string "None"; change peft.enabled to
false so downstream non‑LoRA runs (e.g., grpo_megatron.sh) don't inadvertently
enable LoRA, and have LoRA-specific configs or the grpo_megatron_lora.sh
override set peft.enabled=true when needed; also replace lora_dtype: None with a
YAML null (e.g., lora_dtype: null or lora_dtype: ~) so it parses as null rather
than the string "None".
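Applied together, the suggested base-config changes would look roughly like this (a sketch of the `peft` block; field names other than `enabled` and `lora_dtype` are assumptions about the config layout):

```yaml
peft:
  enabled: false    # base config stays non-LoRA; LoRA recipes override this to true
  lora_dtype: null  # YAML null, not the string "None"
```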
In `@nemo_rl/models/policy/workers/megatron_policy_worker.py`:
- Around line 925-933: The current check uses "if ref_megatron_cfg is not None"
which is always true because ref_megatron_cfg is always created; change the
guard to verify PEFT is enabled (e.g. if self.use_peft and ref_megatron_cfg is
not None) before creating and registering the PEFT pre-wrap hook via
_create_peft_pre_wrap_hook(ref_megatron_cfg, ref_state), calling
ref_megatron_cfg.model.register_pre_wrap_hook(pre_peft_hook), composing
composed_peft_hook, and extending ref_pre_wrap_hooks so LoRA wrapping only
applies when self.use_peft is true.
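The suggested guard can be illustrated with a small stand-in (hypothetical names; the real worker registers the hook via `ref_megatron_cfg.model.register_pre_wrap_hook`):

```python
def maybe_register_peft_hook(use_peft, ref_cfg, create_hook, registered_hooks):
    """Register a PEFT pre-wrap hook only when PEFT is actually enabled.

    Without the use_peft check, the hook would be registered on every run,
    because ref_cfg is always constructed even for non-LoRA training.
    """
    if not (use_peft and ref_cfg is not None):
        return None
    hook = create_hook(ref_cfg)
    registered_hooks.append(hook)
    return hook
```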
- Around line 946-960: When self.use_peft is true the current
should_load_checkpoint only checks ref_megatron_cfg.checkpoint.load and ignores
ref_megatron_cfg.checkpoint.pretrained_checkpoint; update the PEFT branch in
megatron_policy_worker.py so should_load_checkpoint mirrors the non-PEFT logic
by checking both ref_megatron_cfg.checkpoint.load and
ref_megatron_cfg.checkpoint.pretrained_checkpoint with checkpoint_exists, and
preserve the existing ref_megatron_cfg.checkpoint.finetune toggling behavior
(still set finetune=False when loading a checkpoint) so the reference model
loads pretrained weights in PEFT scenarios.
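The mirrored check can be sketched like this (a simplification with hypothetical parameter names; the real code reads these fields off `ref_megatron_cfg.checkpoint`):

```python
import os

def should_load_checkpoint(load_dir, pretrained_checkpoint):
    """Mirror the non-PEFT logic: load when either a resumable training
    checkpoint or a pretrained checkpoint directory exists on disk."""
    def checkpoint_exists(path):
        return path is not None and os.path.isdir(path)
    return checkpoint_exists(load_dir) or checkpoint_exists(pretrained_checkpoint)
```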
🧹 Nitpick comments (2)
nemo_rl/models/policy/workers/megatron_policy_worker.py (2)
1571-1573: Commented-out code without explanation. Per coding guidelines, commented-out code should include a comment describing why it is retained, or be removed before merging. Lines 1571-1573 and 1565 have commented-out `load_state_dict` calls with no rationale.
904-943: Duplicate LoRA construction: extract a helper. The LoRA instantiation block (lines 906-919) is nearly identical to lines 308-320 in `setup_megatron_model`. Consider extracting a shared helper to avoid copy-paste divergence.
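One way to de-duplicate would be a small factory shared by both call sites. This is a sketch with assumed field names (`dim`, `alpha`, `dropout`, and the target-module defaults mirror common LoRA configs, not necessarily the worker's exact ones):

```python
def build_lora_kwargs(peft_cfg):
    """Collect LoRA constructor arguments in one place so the two call
    sites (setup_megatron_model and the reference-model path) cannot
    drift apart."""
    return {
        "target_modules": peft_cfg.get("target_modules", ["linear_qkv", "linear_proj"]),
        "dim": peft_cfg.get("dim", 8),
        "alpha": peft_cfg.get("alpha", 16),
        "dropout": peft_cfg.get("dropout", 0.0),
        "lora_dtype": peft_cfg.get("lora_dtype"),
    }
```

Each call site would then do `LoRA(**build_lora_kwargs(cfg))`, so a new field only has to be added once.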
❌ Submodule Fast-Forward Check Failed (based on commit 4c0e216, PR #1889).
✅ Properly updated: Megatron-Bridge (PR branch is ahead of main, fast-forward).
❌ Needs attention: Megatron-LM (commits have diverged from a common ancestor).
Please ensure all submodule commits are fast-forwards of the main branch before merging.
Review thread on examples/configs/recipes/llm/grpo-nanov3-30BA3B-1n8g-megatron-lora.yaml (outdated, resolved).
terrykong left a comment:
another round of review
also, DCO and lint need to be resolved before final merge
Fixed DCO and ran the linter. Some files I didn't touch for this PR also had small lint fixes.
terrykong left a comment:
@yaoyu-33 @ananthsub can you take a pass? some of the megatron changes could use your expertise
@terrykong Does anything else need to be done here?
@vadam5 this PR is all good. I just merged another PR to create another CI level to help speed up the evaluation of PRs; I'll help resolve this and kick off those tests.
Signed-off-by: Anna Shors <ashors@nvidia.com>
Signed-off-by: Virginia Wu <vadams@nvidia.com>
Signed-off-by: Virginia Wu <78445382+vadam5@users.noreply.github.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Anna Shors <ashors@nvidia.com>
Co-authored-by: root <root@pool0-00689.cm.cluster>
Co-authored-by: Terry Kong <terryk@nvidia.com>
GRPO LoRA for Megatron Core has landed (#1889), so remove the "coming soon" note and reword the LoRA news bullets for consistency. Signed-off-by: Terry Kong <terryk@nvidia.com>
What does this PR do?
Supports sync, async, and non-colocated LoRA GRPO via the Megatron path, with weight merging for rollouts: the PR merges LoRA adapter weights into the base model weights before exporting them to vLLM.
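The merge itself follows the standard LoRA identity, W_merged = W + (alpha/r)·(B·A). A pure-Python sketch on nested lists (real code would of course operate on torch tensors before export):

```python
def merge_lora(w, a, b, alpha, r):
    """Fold LoRA adapters into a base weight matrix.

    w: (out x in) base weights, a: (r x in), b: (out x r).
    Returns W + (alpha / r) * (B @ A) without modifying w.
    """
    scale = alpha / r
    out_dim, in_dim = len(w), len(w[0])
    merged = [row[:] for row in w]  # copy so the base weights are untouched
    for i in range(out_dim):
        for j in range(in_dim):
            delta = sum(b[i][k] * a[k][j] for k in range(r))
            merged[i][j] += scale * delta
    return merged
```

After merging, the exported weights look like an ordinary dense checkpoint, which is what lets vLLM serve rollouts without any LoRA-aware handling.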
Issues
closes #1372
closes #1371
closes #833
Usage
Before your PR is "Ready for review"
Pre checks:
Additional Information
https://wandb.ai/nvidia/nemo-rl?nw=s1m0n39d4le
Summary by CodeRabbit
New Features
Tests
Chores