fix(optim): adagrad sparse multitensor incorrect early exit by jon-chuang · Pull Request #110454 · pytorch/pytorch · GitHub

Conversation

@jon-chuang
Collaborator

@jon-chuang jon-chuang commented Oct 3, 2023

Fixes #110444 (comment)

This PR:
Passes

Main:

```
test/optim/test_optim.py::TestOptim::test_adagrad_sparse FAILED [0.0058s]

=============================== FAILURES ===============================
______________________ TestOptim.test_adagrad_sparse ___________________
Traceback (most recent call last):
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/test/optim/test_optim.py", line 1448, in test_adagrad_sparse
    self._test_rosenbrock_sparse(
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/test/optim/test_optim.py", line 128, in _test_rosenbrock_sparse
    self.assertEqual(params, params_c, atol=1e-6, rtol=1e-6)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/testing/_internal/common_utils.py", line 3309, in assertEqual
    raise error_metas.pop()[0].to_error(
AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 2 (50.0%)
Greatest absolute difference: 0.09999999999993325 at index (1,) (up to 1e-06 allowed)
Greatest relative difference: 0.06249999999996089 at index (1,) (up to 1e-06 allowed)
```
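
For context, the title points to an early exit in the multi-tensor (foreach) path; a common way such a bug arises is returning from inside the per-group loop instead of continuing, so parameters in later groups never get updated. Below is a minimal, self-contained sketch of that pattern; it is an illustration only, not the actual adagrad code touched by this PR, and the grouping and update shown here are assumptions.

```python
import torch

# Hypothetical illustration of an "incorrect early exit" in a multi-tensor
# optimizer step; NOT the actual torch.optim.Adagrad implementation.
def buggy_multi_tensor_step(groups, lr=0.1):
    # groups: list of (params, grads) pairs, standing in for the per-device /
    # per-dtype grouping a foreach optimizer performs internally.
    for params, grads in groups:
        if any(g.is_sparse for g in grads):
            # Sparse grads take a slower per-tensor fallback...
            for p, g in zip(params, grads):
                p.add_(g.to_dense(), alpha=-lr)
            # ...but returning here exits the whole step, silently skipping any
            # remaining groups. The fix for this pattern is `continue`.
            return
        # Dense groups get the fast foreach update.
        torch._foreach_add_(params, grads, alpha=-lr)
```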

CC: @janeyx99

@pytorch-bot

pytorch-bot bot commented Oct 3, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110454

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 86aecab with merge base 4069d1d:

UNSTABLE - The following jobs failed, but the failures were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jon-chuang jon-chuang changed the title from "fix(optim): adagrad sparse multitensor early exit" to "fix(optim): adagrad sparse multitensor incorrect early exit" Oct 3, 2023
```python
if w:
    grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
else:
    param.grad = x.to_dense()
```
Collaborator Author

@jon-chuang jon-chuang Oct 3, 2023


Note to reviewer:

Removing `to_dense` speeds up the test time from 200s to 8s for adagrad. `sgd_sparse` is still very slow because the optimizer itself is slow.

Instead, we instantiate the cyclic grads directly from the grad in the dense case.
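
A rough sketch of that note (names follow the snippet above; the helper itself is hypothetical, assumes `grad` is a 2-element sequence of gradient values, and is not the PR's exact code): build the alternating ("cyclic") dense gradients directly instead of going through a sparse tensor and `to_dense()`.

```python
import torch

# Hypothetical helper: construct the alternating dense gradient directly,
# avoiding a sparse-construction + .to_dense() round trip in the dense case.
def cyclic_dense_grad(grad, param, use_first):
    if use_first:
        # keep only the first gradient component on this iteration
        return torch.tensor([grad[0], 0.0], dtype=param.dtype)
    # keep only the second gradient component on the alternate iteration
    return torch.tensor([0.0, grad[1]], dtype=param.dtype)

# e.g. alternating the pattern across steps:
# param.grad = cyclic_dense_grad([1.5, -0.5], param, use_first=(step % 2 == 0))
```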

@colesbury colesbury added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Oct 3, 2023
Contributor

@janeyx99 janeyx99 left a comment


Looks good overall! Thanks for the catch and the quick fix.

I would still like to move the SGD testing to its own PR just for modularity, even if it is only 1 line.

@jon-chuang
Collaborator Author

Hello @janeyx99, let's merge this to unblock #110562?

(Anyway, apologies for not using ghstack; I'll try to get the hang of it shortly...)

Contributor

@janeyx99 janeyx99 left a comment


awesome!

@janeyx99
Contributor

janeyx99 commented Oct 5, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Oct 5, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 5, 2023
… against list comprehensions (e.g. complex conversion) (#110613)

Fully fixes: #110506

Depends: #110607
Potential merge conflicts:
- #110339
- #110345
- #110454

Related:
- #110606 (we can apply the improvements here orthogonally to the complex support)

### Results

Benchmark: 100 params.

Breakdowns (float32, dynamo):
```
Adagrad: this PR: 4.4s, main: 8.8s
Adam: this PR: 2.1s, main: 9.8s
AdamW: this PR: 2.5s, main: 8.2s
ASGD: this PR: 3.1s, main: 8.5s
RMSProp: this PR: 1.3s, main: 4.2s
RProp: this PR: 6.7s, main: 14.9s
```

Notes:
1. Adagrad is still slow due to the `_get_value` list comprehension. Can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing the capturable path.
2. Adamax is not actually compiled (it is currently disabled).
3. Inductor compile time is quite variable. We calculate the dynamo time by subtracting the `call_user_compiler` timing from the `compile_inner` timing (a rough sketch of this follows below).
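
For reference, a hedged sketch of that subtraction; the `compilation_time_metrics` attribute and the exact key names it holds are assumptions here, not something stated in this PR.

```python
import torch._dynamo.utils as dynamo_utils

def estimated_dynamo_seconds():
    # Assumption: compilation_time_metrics maps timed-function names to lists
    # of wall-clock durations, including entries for compile_inner (total
    # frontend compile time) and call_user_compiler (backend, e.g. inductor).
    metrics = dynamo_utils.compilation_time_metrics

    def total(needle):
        return sum(sum(times) for name, times in metrics.items() if needle in name)

    # dynamo-only time ~= total compile time minus time spent in the backend
    return total("compile_inner") - total("call_user_compiler")
```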

<details>

This PR:
```
Adagrad (torch.float32): 28.47496461868286s
Adagrad (torch.complex64): 29.379547357559204s
Adam (torch.float32): 17.334211587905884s
Adam (torch.complex64): 29.637500524520874s
Adamax (torch.float32): 2.4749321937561035s
Adamax (torch.complex64): 3.1997995376586914s
AdamW (torch.float32): 18.06532859802246s
AdamW (torch.complex64): 28.25661015510559s
ASGD (torch.float32): 23.70255398750305s
ASGD (torch.complex64): 25.33756995201111s
RMSprop (torch.float32): 7.964028596878052s
RMSprop (torch.complex64): 12.909599781036377s
Rprop (torch.float32): 30.512362003326416s
Rprop (torch.complex64): 44.74405765533447s
```

Main
```
Adagrad (torch.float32): 26.919506072998047s
Adagrad (torch.complex64): 35.190622091293335s
Adam (torch.float32): 25.715000867843628s
Adam (torch.complex64): 24.17716670036316s
Adamax (torch.float32): 2.4404726028442383s
Adamax (torch.complex64): 3.3538928031921387s
AdamW (torch.float32): 25.2022807598114s
AdamW (torch.complex64): 28.915700912475586s
ASGD (torch.float32): 24.108731985092163s
ASGD (torch.complex64): 26.589075088500977s
RMSprop (torch.float32): 10.781344175338745s
RMSprop (torch.complex64): 15.136352777481079s
Rprop (torch.float32): 42.46482181549072s
Rprop (torch.complex64): 48.28277635574341s
```

It seems this doesn't help the complex case by much (but that's not the majority case). torch.float32 is generally positive; where it does not show a drastic improvement, or regresses, that is due to inductor variance (confirmed by manually inspecting the logs).

</details>

### Benchmark Script
```python
import torch
import time
from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop

OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop]
DTYPES = [torch.float, torch.cfloat]

NUM_PARAMS = 100
kwargs = { "lr": 0.01, "foreach": True }
summary = []

for optim_cls in OPTIMS:
    for dtype in DTYPES:
        torch._dynamo.reset()
        # torch._inductor.metrics.reset()
        input = torch.ones([10, 10], dtype=dtype, device="cuda:0")
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)]
        )

        model(input).sum().abs().backward()
        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = torch.compile(opt_compiled.step)

        with torch.set_grad_enabled(False):
            start_time = time.time()
            compiled_step()
            summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s")

        print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times())

for s in summary:
    print(s)
```

CC: @janeyx99 @mlazos
Pull Request resolved: #110613
Approved by: https://github.com/janeyx99
pytorchmergebot pushed a commit that referenced this pull request Oct 6, 2023
Follow-up to: #110454, which defines the infra for sparse multi-tensor optimizer testing

Pull Request resolved: #110562
Approved by: https://github.com/janeyx99

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, open source, release notes: optim, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

[optim]: Adagrad with sparse grad is not on multi_tensor testing path + bug
