Adding lowering to persistent-tma device kernel for _scaled_mm by drisspg · Pull Request #142045 · pytorch/pytorch · GitHub

Conversation

@drisspg
Contributor

@drisspg drisspg commented Dec 4, 2024

Stack from ghstack (oldest at bottom):

Summary

This PR adds an alternative Triton lowering for _scaled_mm, using an updated mm template that combines persistent scheduling with TMAs on the A and B matrices.

Limitations:

  • This implementation does not work with bias values (Inductor does not currently allow optional tensor input arguments; see the workaround at https://github.com/pytorch/pytorch/blob/0602676c8df2d1f85b28a16ec650fbfa844145ce/torch/_inductor/kernel/mm_scaled.py#L106). The plan is to remove this workaround and enforce that both scaling and bias are applied as epilogues onto the existing templates.
  • The K dim must be 32 or greater for this lowering to take effect.
  • Gated by a config flag (currently defaults to off; arguably it should be on). A usage sketch follows this list.
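
For reference, here is a minimal sketch of turning the new lowering on and compiling a scaled matmul. The flag name (`config.triton.enable_persistent_tma_matmul`) is taken from the follow-up PR #142101 description quoted later in this thread; treat the exact name, its default, and the `torch._scaled_mm` keyword usage as assumptions that may drift across versions.

```python
# Hedged sketch: enable the persistent+TMA lowering for _scaled_mm and compile.
import torch
import torch._inductor.config as inductor_config

# Assumed flag name; defaults to off per the PR description.
inductor_config.triton.enable_persistent_tma_matmul = True

def fp8_mm(a_fp8, b_fp8, scale_a, scale_b):
    return torch._scaled_mm(
        a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b,
        out_dtype=torch.bfloat16, use_fast_accum=True,
    )

# max-autotune is what lets Inductor consider the templated GEMM choices.
compiled_fp8_mm = torch.compile(fp8_mm, mode="max-autotune-no-cudagraphs")
```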

Testing

We don't have any tests exercising this code in CI/CD, but I updated the relevant tests in test_fp8 and they are all green:
(Screenshot: passing test_fp8 results, 2024-12-05)

Follow Ups

  • Update the base mm Triton templates so that mm/addmm/scaled_mm all use the same template with their respective epilogues.
  • Tune the persistent kernel configs. I found configs that work for my problem shapes but need to do some more NCU work.

Some profiling code I was using

Code I am using to iterate with:

import torch
from dataclasses import dataclass
from jsonargparse import CLI
import logging
from pathlib import Path

from transformer_nuggets.utils.benchmark import ProfileConfig, profile_function
from torchao.float8.inference import (
    addmm_float8_unwrapped_inference,
    preprocess_data,
    Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
    matmul_persistent,
    matmul_tma_persistent,
    matmul_device_tma_persistent,
)
from enum import Enum

logging.getLogger("transformer_nuggets").setLevel(logging.INFO)


class FP8Kernel(Enum):
    PERSISTENT = "Persistent"
    PERSISTENT_TMA = "Persistent-TMA"
    DEVICE_TMA = "Device-TMA"
    SCALED_MM = "Scaled-MM"


class ScalingStrategy(Enum):
    PER_TENSOR = "PerTensor"
    PER_ROW = "PerRow"


@dataclass(frozen=True)
class ExperimentConfig:
    M: int
    K: int
    N: int
    scaling_strategy: ScalingStrategy
    fp8_kernel: FP8Kernel
    compile: bool


def get_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    scaling_strategy: ScalingStrategy,
    fp8_kernel: FP8Kernel,
):
    A_fp8 = A.to(torch.float8_e4m3fn)
    B_fp8 = B.to(torch.float8_e4m3fn)
    A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))
    
    if scaling_strategy == ScalingStrategy.PER_TENSOR:
        a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    elif scaling_strategy == ScalingStrategy.PER_ROW:
        a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
        b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
    else:
        raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")

    # Only the Scaled-MM path is exercised in this snippet; the other kernels
    # are imported above but not dispatched here.
    assert fp8_kernel == FP8Kernel.SCALED_MM
    return lambda: addmm_float8_unwrapped_inference(
        A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
    )


def run_matmul(config: ExperimentConfig):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
    B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)

    fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)


    if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")


    _ = fp8_matmul()

    return


def main():
    torch.random.manual_seed(123)

    # Define your experiment configuration here
    config = ExperimentConfig(
        M=8192,
        K=8192,
        N=8192,
        scaling_strategy=ScalingStrategy.PER_TENSOR,
        fp8_kernel=FP8Kernel.SCALED_MM,
        compile=True,
    )

    run_matmul(config)


if __name__ == "__main__":
    CLI(main)
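
For context, `addmm_float8_unwrapped_inference` from torchao ultimately dispatches to `torch._scaled_mm`; below is a minimal sketch of the equivalent direct call with per-row scales. The exact signature and the column-major requirement on the second operand are assumptions based on recent PyTorch builds (and on what `preprocess_data` arranges above).

```python
# Minimal sketch of calling torch._scaled_mm directly with per-row scales.
# Signature/layout details are assumed from recent PyTorch and may differ.
import torch

M, K, N = 8192, 8192, 8192
A_fp8 = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
# The second operand is expected in column-major layout, hence the transpose.
B_fp8 = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn).t()

scale_a = torch.ones(M, 1, device="cuda", dtype=torch.float32)  # one scale per row of A
scale_b = torch.ones(1, N, device="cuda", dtype=torch.float32)  # one scale per column of B

out = torch._scaled_mm(
    A_fp8, B_fp8, scale_a=scale_a, scale_b=scale_b,
    out_dtype=torch.bfloat16, use_fast_accum=True,
)
```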

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Dec 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142045

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit da679d2 with merge base 9dffd12:

BROKEN TRUNK - The following jobs failed but were already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@aakhundov added the `topic: not user facing` label Dec 5, 2024
@eellison eellison self-requested a review December 5, 2024 19:26
@drisspg changed the title from "[Prototype] Adding lowering to persistent-tma device kernel for _scaled_mm" to "Adding lowering to persistent-tma device kernel for _scaled_mm" Dec 6, 2024
Contributor

@eellison eellison left a comment


Would be nice to specify this as an epilogue or reuse the TMA lowering in a future PR.

"""Defines the grid for persistent kernels."""
return (
min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),
1,
Contributor


What are the extra 1s here, just for my own knowledge?

Contributor Author


I think this has to do with how we thread the launch config to Triton; without it I get:

  File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 1079, in run
    return launcher(
           ^^^^^^^^^
  File "<string>", line 5, in launcher
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
ValueError: not enough values to unpack (expected 3, got 1)

Contributor


I believe the grid is assumed to be a 3-tuple (or a callable returning one) in Triton grid launches.
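
To make the 3-tuple requirement concrete, here is a small self-contained sketch of the persistent grid (names are illustrative, and `cdiv` is written out inline rather than using the Inductor helper):

```python
# Illustrative sketch of the persistent-kernel grid as a full 3-tuple;
# Triton launch configs are unpacked into three dimensions.
def cdiv(a: int, b: int) -> int:
    return -(a // -b)  # ceiling division

def persistent_mm_grid(M: int, N: int, meta: dict):
    """Launch at most NUM_SMS programs; each one persists over multiple output tiles."""
    return (
        min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),
        1,
        1,
    )

# Example: 8192x8192 output, 128x128 tiles, 132 SMs (H100-like).
grid = persistent_mm_grid(8192, 8192, {"NUM_SMS": 132, "BLOCK_M": 128, "BLOCK_N": 128})
# -> (132, 1, 1): 4096 output tiles shared among 132 persistent programs.
```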

@drisspg
Contributor Author

drisspg commented Dec 8, 2024

@pytorchbot merge

@pytorch-bot added the `ciflow/trunk` label Dec 8, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
Esquains pushed a commit to Esquains/study1 that referenced this pull request Dec 15, 2024
pytorchmergebot pushed a commit that referenced this pull request Dec 16, 2024
This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the `tuned_mm` and `tuned_addmm` lowerings. The persistent+TMA choices are added to the GEMM autotuning if all of the following hold (checked by the `use_triton_tma_template` helper; a rough sketch of these checks follows the list):

1. The min. hardware and Triton version requirements are met for the TMA support.

2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).

3. The `config.triton.enable_persistent_tma_matmul` is set to `True`.
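
A rough, illustrative sketch of those three gating checks (this is not the actual `use_triton_tma_template` implementation; the Hopper capability requirement and the 16-byte alignment constant are assumptions, and the Triton version check is omitted):

```python
# Illustrative only: approximates the gating described above.
import torch
from torch._inductor import config as inductor_config

def could_use_tma_template(a: torch.Tensor, b: torch.Tensor) -> bool:
    # 1. Hardware requirement (assumed: TMA needs Hopper / SM90 or newer).
    hw_ok = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
    # 2. Inputs compatible with the Triton TMA API: contiguous and 16-byte aligned.
    inputs_ok = all(t.is_contiguous() and t.data_ptr() % 16 == 0 for t in (a, b))
    # 3. The Inductor config flag is enabled.
    return hw_ok and inputs_ok and inductor_config.triton.enable_persistent_tma_matmul
```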

Additional notes:

1. As added in this PR, the TMA uses are not compatible with prolog / epilogue fusion. To this end, in the new Triton template we currently support: TMA-based loads of A/B, but no prologue fusion; epilogue fusion, but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.

2. The current Triton TMA API (`experimental_device_tensormap_create2d`) does not support strides. Due to this, we limit the applicability of the new Triton template to the cases where the inputs are contiguous.

3. The transposed layouts of A and / or B are supported by passing the constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (this should have no effect on the kernel perf, as decided at the Triton compilation time).

4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in triton-lang/triton#5290) in the new Triton template, which should allow lifting 2 and 3 above.

5. The configs for the new Triton template in `persistent_mm_kernel_configs` are preliminary. We should do more perf exploration and possibly augment the config in a follow-up.

6. This PR is rebased onto and unifies with two related PRs landed previously: #142045 (some infra unification with the persistent+TMA template for _scaled_mm) and #134532 (add possibility to disable prolog fusion for selected choices).

7. The current Triton TMA API only supports 1D and 2D descriptors (even after triton-lang/triton#5290, see [here](https://github.com/triton-lang/triton/blob/9829ce87ccb333a2b264b3a80b39a534bfa865ac/python/triton/language/core.py#L1957)). For now, this blocks adding persistent+TMA template for `torch.bmm`.

Pull Request resolved: #142101
Approved by: https://github.com/drisspg, https://github.com/eellison
@github-actions github-actions bot deleted the gh/drisspg/93/head branch January 8, 2025 02:04
persistent_mm_kernel_configs = [
    {"config": (128, 128, 64, 3, 8), "cond": True},
Contributor


What's the 4th arg, the one taking the value of 3?
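
For context, a hedged reading of the tuple, assuming these entries follow the same convention as Inductor's other mm kernel configs, where the fields are (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps); under that assumption the 4th value (3) would be num_stages:

```python
# Assumed field order (not confirmed in this thread):
#           (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
persistent_mm_kernel_configs = [
    {"config": (128, 128, 64, 3, 8), "cond": True},
]
```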
