perf(inductor): use for loop with shortcut in `Optimizer`s to speedup against list comprehensions (e.g. complex conversion) by jon-chuang · Pull Request #110613 · pytorch/pytorch · GitHub

Conversation

@jon-chuang (Collaborator) commented Oct 5, 2023

Fully fixes: #110506

Depends: #110607
Potential merge conflicts:

Related:

Results

Benchmark: 100 params.

Breakdowns (float32, dynamo):

Adagrad: this PR: 4.4s, main: 8.8s
Adam: this PR: 2.1s, main: 9.8s
AdamW: this PR: 2.5s, main: 8.2s
ASGD: this PR: 3.1s, main: 8.5s
RMSProp: this PR: 1.3s, main: 4.2s
RProp: this PR: 6.7s, main: 14.9s

Notes:

  1. Adagrad is still slow due to the _get_value list comprehension. This can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing the capturable path.
  2. Adamax is not actually compiled (it is currently disabled).
  3. Inductor compile time is quite variable. We calculate the dynamo time by subtracting the call_user_compiler time from the compile_inner timing.

This PR:

Adagrad (torch.float32): 28.47496461868286s
Adagrad (torch.complex64): 29.379547357559204s
Adam (torch.float32): 17.334211587905884s
Adam (torch.complex64): 29.637500524520874s
Adamax (torch.float32): 2.4749321937561035s
Adamax (torch.complex64): 3.1997995376586914s
AdamW (torch.float32): 18.06532859802246s
AdamW (torch.complex64): 28.25661015510559s
ASGD (torch.float32): 23.70255398750305s
ASGD (torch.complex64): 25.33756995201111s
RMSprop (torch.float32): 7.964028596878052s
RMSprop (torch.complex64): 12.909599781036377s
Rprop (torch.float32): 30.512362003326416s
Rprop (torch.complex64): 44.74405765533447s

Main

Adagrad (torch.float32): 26.919506072998047s
Adagrad (torch.complex64): 35.190622091293335s
Adam (torch.float32): 25.715000867843628s
Adam (torch.complex64): 24.17716670036316s
Adamax (torch.float32): 2.4404726028442383s
Adamax (torch.complex64): 3.3538928031921387s
AdamW (torch.float32): 25.2022807598114s
AdamW (torch.complex64): 28.915700912475586s
ASGD (torch.float32): 24.108731985092163s
ASGD (torch.complex64): 26.589075088500977s
RMSprop (torch.float32): 10.781344175338745s
RMSprop (torch.complex64): 15.136352777481079s
Rprop (torch.float32): 42.46482181549072s
Rprop (torch.complex64): 48.28277635574341s

It seems this doesn't help the complex case by much (but that's not the majority case). The torch.float32 results are generally positive; where a run does not show a drastic improvement or appears to regress, it is due to inductor variance (confirmed by manually inspecting the logs).
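For context, here is a minimal sketch of the shape of the change, using the variable names from the foreach implementations reviewed below. The helper name _to_real_views is hypothetical; the real code inlines this loop in each optimizer's multi-tensor path.

import torch

def _to_real_views(device_params):
    # Before: a list comprehension that always rebuilds the list, e.g.
    #   [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]
    # After: a plain for loop that mutates the list in place and only rewrites
    # the entries that are actually complex; dynamo traces this noticeably
    # faster (see the float32 breakdowns above).
    for i in range(len(device_params)):
        if torch.is_complex(device_params[i]):
            device_params[i] = torch.view_as_real(device_params[i])
    return device_params

The same loop shape applies to the gradient and optimizer-state tensor lists.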

Benchmark Script

import torch
import time
from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop

OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop]
DTYPES = [torch.float, torch.cfloat]

NUM_PARAMS = 100
kwargs = { "lr": 0.01, "foreach": True }
summary = []

for optim_cls in OPTIMS:
    for dtype in DTYPES:
        # Reset dynamo so each optimizer/dtype pair is compiled from scratch.
        torch._dynamo.reset()
        # torch._inductor.metrics.reset()
        input = torch.ones([10, 10], dtype=dtype, device="cuda:0")
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)]
        )

        # Populate .grad on every parameter before stepping the optimizer.
        model(input).sum().abs().backward()
        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = torch.compile(opt_compiled.step)

        with torch.set_grad_enabled(False):
            # Time the first call, which includes dynamo + inductor compilation.
            start_time = time.time()
            compiled_step()
            summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s")

        print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times())

for s in summary:
    print(s)

CC: @janeyx99 @mlazos

@pytorch-bot (bot) commented Oct 5, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110613

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit b123833 with merge base cf1b494:

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jon-chuang changed the title from "Jon chuang/fast multi tensor optim" to "perf(inductor): use for loop with shortcut in Optimizers to speedup against list comprehensions (e.g. complex conversion)" Oct 5, 2023
@jon-chuang (Collaborator, Author) commented Oct 5, 2023

Actually, I know how to speed this up even further. We can add a has_complex flag, similar to has_sparse_grad.

See the use of has_sparse_grad to shortcut the any iterator in adagrad

This is currently being canary tested in #110607
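For reference, the has_sparse_grad shortcut works because Python's `and` short-circuits: when the caller already knows the flag is False, the any(...) generator is never consumed. A has_complex flag can follow the same shape. A minimal sketch below; the helper name _group_flags is hypothetical, and the exact plumbing is what #110607 is testing.

import torch

def _group_flags(device_params, device_grads, has_sparse_grad, has_complex):
    # `and` short-circuits: if has_sparse_grad is already False, the per-tensor
    # any(...) scan below never runs (this is the existing Adagrad pattern).
    device_has_sparse_grad = has_sparse_grad and any(g.is_sparse for g in device_grads)
    # Proposed analogue for complex support, mirroring the same shortcut.
    device_has_complex = has_complex and any(torch.is_complex(p) for p in device_params)
    return device_has_sparse_grad, device_has_complex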

@janeyx99 (Contributor) left a comment

Approval contingent on green CI :)

-    device_params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]
+    for i in range(len(device_params)):
+        if torch.is_complex(device_params[i]):
Contributor review comment:

nice

for ((device_params, device_grads, device_state_sums, device_state_steps), _) in grouped_tensorlists.values():

-    device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
+    device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)
Contributor review comment:

oh, very good heuristic catch

@jon-chuang (Collaborator, Author) replied:

Yes, we can also do this for has_complex. Being tested in: #110607

Perhaps we should try to cover a smaller surface area with just the Adam changes (#110607) and then proceed with the large change here if things go smoothly?

@jon-chuang (Collaborator, Author) replied:

Anyway, the has_complex changes are orthogonal, so I'm also down to just merge this one first and rebase #110607 and further has_complex improvements on this PR.
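To make the orthogonal has_complex idea concrete, a rough sketch follows; the names are hypothetical and the eventual change landed via #110607/#110631, which may differ in detail. The flag is computed once while the param group is set up, so an all-real group skips the conversion loop entirely.

import torch

def _init_group_has_complex(group):
    # Computed once per param group, mirroring how has_sparse_grad is tracked.
    return any(torch.is_complex(p) for p in group["params"])

def _maybe_view_complex_as_real(device_params, has_complex):
    # An all-real group skips the whole loop thanks to the precomputed flag.
    if has_complex:
        for i in range(len(device_params)):
            if torch.is_complex(device_params[i]):
                device_params[i] = torch.view_as_real(device_params[i])
    return device_params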

@janeyx99 (Contributor) commented Oct 5, 2023

Also, please land the Adagrad sparse fix separately

@janeyx99 (Contributor) left a comment

Hey! Don't land before addressing my new comment

@jon-chuang (Collaborator, Author) commented

@janeyx99 to be more explicit, this is how the single_tensor case always shortcuts:

if torch.is_complex(param):

(the same check guards the complex handling in each of the single-tensor implementations; one of them precomputes it as a flag:)

if is_complex_param:
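For completeness, a simplified sketch of that single-tensor shortcut. The real single-tensor implementations also view the optimizer state (e.g. exp_avg) as real, and the update math is optimizer-specific; a plain SGD-style update stands in here, and the function name is hypothetical.

import torch

@torch.no_grad()
def _single_tensor_step_sketch(params, grads, lr=0.01):
    for param, grad in zip(params, grads):
        if torch.is_complex(param):
            # View the complex tensors as real [..., 2] so the update below runs
            # on real storage; the in-place add still updates the original
            # complex parameter through the view.
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
        param.add_(grad, alpha=-lr)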

@jon-chuang (Collaborator, Author) commented

@pytorchbot merge

@pytorch-bot (bot) added the ciflow/trunk (Trigger trunk jobs on your pull request) label Oct 5, 2023
@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

pytorchmergebot pushed commits that referenced this pull request Oct 6, 2023
…_complex` shortcut (#110631)

Partial fix: #110606

More on `has_complex` shortcut: #110613 (comment)

CC: @janeyx99, @mlazos, @lezcano
Pull Request resolved: #110631
Approved by: https://github.com/lezcano

Development

Successfully merging this pull request may close these issues:

[dynamo] Slow compile times for optimizers due to for loops