[small][muon] Use addmm for Newton–Schulz orthogonalization #161379
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161379
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 4758522 with merge base 74280d0. This comment was automatically generated by Dr. CI and updates every 15 minutes.
I think it's ok to make it the default implementation with addmm as it's strictly mathematically more accurate! (The differences are due to fusion, I would imagine)
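To see the kind of fusion-induced difference being described, here is a minimal standalone comparison of the unfused `scale + add + matmul` against `torch.addmm`. This is illustrative only: the shapes, coefficients, and dtype are arbitrary choices, not taken from this PR, and whether the two paths differ at all (and by how much) depends on the backend.

```python
import torch

torch.manual_seed(0)
M = torch.randn(256, 256, dtype=torch.bfloat16)
N = torch.randn(256, 256, dtype=torch.bfloat16)
beta, alpha = -4.7750, 2.0315  # illustrative coefficients, not from this PR

# Unfused: the matmul, the scale, and the add each round to bf16 separately.
unfused = beta * M + alpha * (M @ N)

# Fused: one kernel computes beta*M + alpha*(M @ N) with fewer intermediate roundings.
fused = torch.addmm(M, M, N, beta=beta, alpha=alpha)

# Any mismatch comes from the intermediate roundings of the unfused path;
# its magnitude (and whether it is nonzero at all) is backend-dependent.
print((unfused - fused).abs().max())
```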
great!
briefly discussed offline with @janeyx99. since
lgtm, the change to line 33 isn't needed anymore btw
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: trunk / verify-cachebench-cpu-test / test (verify_cachebench, 1, 1, linux.2xlarge). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…161379)

A performance optimization: use `torch.addmm`, which fuses `matrix multiply + scale + add` into one op.

**Benchmark**

In training a QWEN-like 0.5B model we observed an average `optimizer.step()` latency improvement from ~44.5 ms (matmul) to ~27.4 ms (addmm): a **1.62×** speedup.

matmul
<img width="1403" height="600" alt="Screenshot 2025-08-24 at 3 15 37 PM" src="https://github.com/user-attachments/assets/a77a68d4-da3c-473a-97f0-e6ef0a3b46d9" />

addmm
<img width="1426" height="602" alt="Screenshot 2025-08-24 at 3 13 42 PM" src="https://github.com/user-attachments/assets/e493af36-44d3-4026-9f7c-fd0f9cdbc7e5" />

**Testing**

End-to-end training: we used a training script that pre-trains a QWEN-like model on the `openwebtext-100k` dataset. We trained for one epoch, and the resulting loss curves are consistent between plain matmul and addmm.

<img width="1035" height="434" alt="Screenshot 2025-08-24 at 2 56 21 PM" src="https://github.com/user-attachments/assets/b96b13e3-0a01-4908-853c-d917b41f3d75" />

Unit test (imports added and the previously undefined `nesterov` pinned so the snippet runs standalone; the `Muon` import path is assumed):

```python
import copy

import torch
from torch.nn import Linear, MSELoss
from torch.optim import Muon  # import path assumed for this snippet

# dummy model and data
model0 = Linear(10, 10, bias=False)
model1 = copy.deepcopy(model0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)
loss = MSELoss()

lr = 1e-3
wd = 0.1
momentum = 0.95
nesterov = True  # undefined in the original snippet; True assumed

opt_ref_muon = Muon(
    params=model0.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
    nesterov=nesterov,
    adjust_lr_fn="original",
)
opt_exp_muon = Muon(
    params=model1.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
    nesterov=nesterov,
    adjust_lr_fn="original",
    use_addmm=True,
)

out_ref = model0(inputs)
loss_ref = loss(out_ref, targets)
opt_ref_muon.zero_grad()
loss_ref.backward()
opt_ref_muon.step()

out_exp = model1(inputs)
loss_exp = loss(out_exp, targets)
opt_exp_muon.zero_grad()
loss_exp.backward()
opt_exp_muon.step()

for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
    torch.testing.assert_close(p_ref, p_exp)
```

The test shows a numeric difference, but this is expected at bf16 precision:

```
Mismatched elements: 96 / 100 (96.0%)
Greatest absolute difference: 8.985400199890137e-05 at index (1, 9) (up to 1e-06 allowed)
Greatest relative difference: 0.007370449136942625 at index (0, 6) (up to 1e-05 allowed)
```

~~Introduced a flag that allows users to opt in, as there are numerical differences relative to the original implementation.~~

Update: since `addmm` fuses the math ops, there are fewer intermediate roundings, so it is more numerically accurate than the original form. Based on this, we make `addmm` the default and only option.

Pull Request resolved: pytorch#161379
Approved by: https://github.com/janeyx99
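For context on where the fusion applies, here is a minimal sketch of the quintic Newton–Schulz orthogonalization step that Muon uses, with each `scale + add + matmul` written as a single `torch.addmm`. The coefficients and the normalization follow the commonly cited Muon reference implementation rather than this PR's diff, so treat them as assumptions.

```python
import torch

def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration: X <- a*X + b*(X X^T) X + c*(X X^T)^2 X.
    # Coefficients below are the common Muon reference values (assumed here).
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    # Dividing by the Frobenius norm bounds the spectral norm by 1,
    # which the iteration needs to converge.
    X = X / (X.norm() + 1e-7)
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        # Unfused form: B = b * A + c * (A @ A); addmm does it in one kernel.
        B = torch.addmm(A, A, A, beta=b, alpha=c)
        # Unfused form: X = a * X + B @ X; again fused into a single addmm.
        X = torch.addmm(X, B, X, beta=a, alpha=1.0)
    if transposed:
        X = X.mT
    return X
```

Each iteration thus issues three GEMM-class kernels (one matmul and two addmm calls) instead of three matmuls plus separate scale-and-add kernels, which is where the `optimizer.step()` latency win reported in the benchmark comes from.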