[MPS][BE] Move FusedOptimizerOps to its own shader by malfet · Pull Request #141092 · pytorch/pytorch · GitHub

Conversation

@malfet (Contributor) commented Nov 20, 2024

[ghstack-poisoned]

pytorch-bot bot commented Nov 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141092

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 1 Unrelated Failure

As of commit a1ba3fa with merge base 0443398:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/mps (Run MPS tests, subset of trunk) and release notes: mps (Release notes category) labels Nov 20, 2024
@malfet malfet requested a review from Skylion007 November 20, 2024 02:23
@malfet malfet added the ciflow/trunk (Trigger trunk jobs on your pull request) label Nov 20, 2024
pytorchmergebot pushed a commit that referenced this pull request Nov 20, 2024
Instead of calling the `REGISTER_FUSED_ADAM_OP` macro with 7 parameters 16 times, define 4 type-parameter macros for each op and then one macro that defines the quartet of ops: Adam, AdamW, and their grad functions.
Pull Request resolved: #141103
Approved by: https://github.com/kulinseth
ghstack dependencies: #141089, #141090, #141092
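The refactor itself lives in C++/Metal macros; as a rough Python analogy only (hypothetical names, not the PR's code), the idea is to expand a single parameterized registration helper over every op/dtype pair instead of spelling out each registration by hand:
```python
# Hypothetical sketch -- a Python analogy for the nested C++ registration macros, not the PR's code.
FUSED_KERNELS = {}

def register_fused_op(op: str, dtype: str) -> None:
    # Stand-in for a single macro expansion: bind an (op, dtype) pair to a kernel name.
    FUSED_KERNELS[(op, dtype)] = f"fused_{op}_{dtype}"

def register_fused_adam_ops(dtypes=("float", "half", "bfloat")) -> None:
    # One entry point covers the whole quartet of ops for every element type,
    # mirroring the "one macro defines the quartet" structure described above.
    for op in ("adam", "adam_amsgrad", "adamw", "adamw_amsgrad"):  # placeholder op names
        for dtype in dtypes:  # placeholder dtype list
            register_fused_op(op, dtype)

register_fused_adam_ops()
```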
pytorchmergebot pushed a commit that referenced this pull request Nov 21, 2024
pytorchmergebot pushed a commit that referenced this pull request Nov 22, 2024
For macOS 14+.

Running the following script (adapted from the one mentioned in #127242):
```python
import torch
from torch.optim import adam, adamw
import torch.utils.benchmark as benchmark
import itertools

def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
    # One optimizer step over every tensor in the group, using the fused or non-fused path.
    fn(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        foreach=False,
        capturable=False,
        fused=fused,
        amsgrad=amsgrad,
        beta1=0.9,
        beta2=0.99,
        lr=1e-3,
        weight_decay=.0,
        eps=1e-5,
        maximize=False,
        grad_scale=None,
        found_inf=None,
    )
    # Drain the MPS queue so the measurement includes kernel execution, not just dispatch.
    torch.mps.synchronize()

device, dtype = "mps", torch.bfloat16

results = []

for num_tensors, numel, adamWflag, amsgrad in itertools.product([10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]):
    print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
    params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
    max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else []
    state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)]
    fn = adamw.adamw if adamWflag else adam.adam

    for fused in [True, False]:

        # blocked_autorange collects samples for at least min_run_time seconds per configuration.
        t = benchmark.Timer(
                stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                label=f'Fused Adam on {device} using {dtype}',
                sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                globals=locals(),
                description=f"Fused: {fused}",
            ).blocked_autorange(min_run_time=5)
        results.append(t)

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
```

Produces the following results on an M4 Pro running macOS 15:
```
[-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------]
                                                                          |  Fused: True  |  Fused: False
1 threads: ----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10        |       283     |      2810
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10       |       277     |      2430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10       |       285     |      2400
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10      |       278     |      2250
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10       |       504     |      2700
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10      |       478     |      2600
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10      |       506     |      2500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10     |       482     |      2300
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10     |      2089     |      4190
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10    |      1940     |      3800
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10    |      2100     |      3770
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10   |      1950     |      3600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50        |       842     |     14000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50       |       835     |     11800
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50       |       845     |     11700
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50      |       855     |     11000
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50       |      1410     |     14000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50      |      1350     |     12000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50      |      1400     |     12000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50     |      1340     |     11000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50     |      9767     |     20400
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50    |      8991     |     18600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50    |      9803     |     18300
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50   |      9070     |     17600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100       |      1600     |     27000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100      |      1600     |     24100
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100      |      1600     |     23500
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100     |      1600     |     21800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100      |      2740     |     26000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100     |      2580     |     24000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100     |      2730     |     25000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100    |      2600     |     23000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100    |     19350     |     39000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100   |     17780     |     37300
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100   |     19400     |     37000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100  |     17900     |     35500
Times are in microseconds (us).
```
Pull Request resolved: #141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #141089, #141090, #141092, #141103
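For reference, a minimal usage sketch (not from the PR) of the fused path this stack enables; it assumes `fused=True` is accepted by Adam/AdamW on the MPS backend on macOS 14+, as the commit messages above describe:
```python
import torch

# Hypothetical end-to-end example, assuming fused Adam/AdamW is available on this MPS build.
model = torch.nn.Linear(128, 128, device="mps", dtype=torch.bfloat16)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

x = torch.randn(32, 128, device="mps", dtype=torch.bfloat16)
loss = model(x).float().square().mean()
loss.backward()
optim.step()        # one fused kernel launch per step instead of many small element-wise ops
optim.zero_grad()
torch.mps.synchronize()  # drain the MPS queue before timing or reading results
```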
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
@github-actions github-actions bot deleted the gh/malfet/61/head branch December 22, 2024 02:09
Labels: ciflow/mps, ciflow/trunk, Merged, release notes: mps