[small][muon] Use addmm for Newton–Schulz orthogonalization by chuanhaozhuge · Pull Request #161379 · pytorch/pytorch

Conversation

@chuanhaozhuge
Contributor

@chuanhaozhuge commented Aug 24, 2025

A performance optimization: the Newton–Schulz orthogonalization step now uses `torch.addmm`, which fuses `matrix multiply + scale + add` into a single op.
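For context, a minimal sketch of what the fusion buys (the coefficients and loop below are the standard Muon Newton–Schulz quintic iteration; the exact code in this PR may differ):

```python
import torch

def newton_schulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Standard Muon quintic Newton-Schulz coefficients.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)  # normalize so the iteration converges
    for _ in range(steps):
        A = X @ X.mT
        # Unfused form:  B = b * A + c * (A @ A);  X = a * X + B @ X
        # Fused form: addmm(input, m1, m2, beta=., alpha=.) = beta*input + alpha*(m1 @ m2)
        B = torch.addmm(A, A, A, beta=b, alpha=c)
        X = torch.addmm(X, B, X, beta=a)
    return X
```

Each `addmm` call replaces a matmul plus a separate elementwise scale-and-add, saving kernel launches and intermediate tensors per iteration.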

Benchmark
In training a QWEN-like 0.5B model, we observed average `optimizer.step()` latency drop from ~44.5 ms (matmul) to ~27.4 ms (addmm): a **1.62×** speedup.
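For reference, a minimal sketch of how such a latency measurement can be taken with CUDA events (the model, sizes, and the stand-in SGD optimizer below are illustrative assumptions, not the benchmark setup from this PR):

```python
import torch
from torch import nn

model = nn.Linear(4096, 4096, bias=False).cuda()
opt = torch.optim.SGD(model.parameters(), lr=1e-3)  # stand-in; the PR benchmarks Muon
x = torch.randn(64, 4096, device="cuda")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
times = []
for step in range(30):
    opt.zero_grad()
    model(x).square().mean().backward()
    start.record()
    opt.step()
    end.record()
    torch.cuda.synchronize()
    if step >= 10:  # discard warm-up iterations
        times.append(start.elapsed_time(end))  # milliseconds
print(f"avg optimizer.step(): {sum(times) / len(times):.2f} ms")
```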

matmul: [profiler trace screenshot]

addmm: [profiler trace screenshot]

Testing
End-to-end training:
We used a training script that pre-trains a QWEN-like model on the `openwebtext-100k` dataset. After one epoch, the resulting loss curves for plain matmul and addmm are consistent.
[loss-curve screenshot: matmul vs. addmm]

Unit test:

```python
import copy

import torch
from torch.nn import Linear, MSELoss
from torch.optim import Muon  # assumed import path

# dummy model and data
model0 = Linear(10, 10, bias=False)
model1 = copy.deepcopy(model0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)
loss = MSELoss()

lr = 1e-3
wd = 0.1
momentum = 0.95
nesterov = True  # was undefined in the snippet as posted

# reference optimizer: original (unfused) Newton-Schulz math
opt_ref_muon = Muon(
    params=model0.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
    nesterov=nesterov,
    adjust_lr_fn="original",
)

# experimental optimizer: fused addmm path
opt_exp_muon = Muon(
    params=model1.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
    nesterov=nesterov,
    adjust_lr_fn="original",
    use_addmm=True,
)

out_ref = model0(inputs)
loss_ref = loss(out_ref, targets)
opt_ref_muon.zero_grad()
loss_ref.backward()
opt_ref_muon.step()

out_exp = model1(inputs)
loss_exp = loss(out_exp, targets)
opt_exp_muon.zero_grad()
loss_exp.backward()
opt_exp_muon.step()

# compare the updated weights of the two models
for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
    torch.testing.assert_close(p_ref, p_exp)
```

The comparison shows numeric differences, but this is expected at bf16 precision:

```
Mismatched elements: 96 / 100 (96.0%)
Greatest absolute difference: 8.985400199890137e-05 at index (1, 9) (up to 1e-06 allowed)
Greatest relative difference: 0.007370449136942625 at index (0, 6) (up to 1e-05 allowed)
```
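The `1e-06`/`1e-05` bounds in that output are tight float32-scale tolerances; presumably the check passes once they are widened to bf16 scale. A sketch of the adjusted assertion (the threshold values are illustrative assumptions, not taken from the PR):

```python
# Compare with bf16-scale tolerances instead of the tight defaults.
for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
    torch.testing.assert_close(p_ref, p_exp, atol=1e-4, rtol=1e-2)
```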

~~Introduced a flag that allows users to opt in, as there are numerical differences relative to the original implementation.~~
Update: since `addmm` fuses the math ops, it performs fewer intermediate roundings and is therefore more numerically accurate than the original form. Based on this, we opt to make `addmm` the default and only option.
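An illustrative way to see the rounding effect (a standalone sketch with assumed shapes and dtype, not code from this PR): the two-step form rounds the scaled term and the matmul to bf16 separately before adding, while `addmm` applies scale and add inside one kernel, so the two results can differ slightly.

```python
import torch

torch.manual_seed(0)
a = 3.4445
A = torch.randn(64, 64, dtype=torch.bfloat16)
X = torch.randn(64, 64, dtype=torch.bfloat16)

two_step = a * X + A @ X               # separate scale, matmul, add
fused = torch.addmm(X, A, X, beta=a)   # beta*X + (A @ X) in one op
print((two_step - fused).abs().max())  # zero or small, backend-dependent
```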

@pytorch-bot

pytorch-bot bot commented Aug 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161379

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4758522 with merge base 74280d0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@janeyx99 left a comment


I think it's ok to make it the default implementation with addmm as it's strictly mathematically more accurate! (The differences are due to fusion, I would imagine)

@janeyx99 added the `topic: performance` label Aug 25, 2025
@toothacher17

great!

@chuanhaozhuge
Contributor Author

briefly discussed offline with @janeyx99. Since `addmm` fuses the math ops, it performs fewer intermediate roundings and is therefore more numerically accurate than the original form. Based on this, we opt to make `addmm` the default and only option.

@chuanhaozhuge marked this pull request as ready for review August 25, 2025 20:48
@chuanhaozhuge requested a review from albanD as a code owner August 25, 2025 20:48
Contributor

@janeyx99 left a comment


lgtm, the change to line 33 isn't needed anymore btw

@chuanhaozhuge
Contributor Author

@pytorchbot merge

@pytorch-bot added the `ciflow/trunk` label Aug 25, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / verify-cachebench-cpu-test / test (verify_cachebench, 1, 1, linux.2xlarge)

Details for Dev Infra team (raised by workflow job)

@chuanhaozhuge
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025

Pull Request resolved: pytorch#161379
Approved by: https://github.com/janeyx99
@github-actions deleted the muon_dev_1 branch September 26, 2025 02:10

Labels

ciflow/trunk · Merged · release notes: optim · topic: performance
