[sparse] Add fast semi-structured sparsification kernels by jcaip · Pull Request #122350 · pytorch/pytorch · GitHub

Conversation

jcaip
Contributor

@jcaip jcaip commented Mar 20, 2024

Stack from ghstack (oldest at bottom):

This PR adds fast semi-structured sparsification kernels to PyTorch.

These kernels enable accelerated semi-structured (2:4) sparsification
directly in PyTorch.

The kernels have been added as ATen native functions.

In particular, three new functions have been added:

  • torch._sparse_semi_structured_tile

This function returns the packed representation and metadata for
both X and X', as well as the thread masks. Note that it applies 2:4
sparsity over a 4x4 tile rather than the usual 1x4 strip.

  • torch._sparse_semi_structured_apply

This function takes an input tensor and the thread masks from the
function above, and returns the packed representation and metadata that
result from applying those thread masks to the input tensor.

  • torch._sparse_semi_structured_apply_dense

This function does the same thing as the above, but returns the result
in the dense representation instead of the packed sparse
representation.

The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors in this format. I've added some
additional documentation on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor yourself.

To this end, two new helper functions have been added:
  • `sparse_semi_structured_tile`
  • `compute_compressed_swizzled_bitmask`
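
To make the intended flow concrete, here is a minimal sketch of how these pieces might fit together. The op names and the `prune_dense_static_sort` classmethod come from this PR, but the import path, argument lists, and return layouts shown are assumptions inferred from the description above, not verified signatures.

```python
# Hedged sketch only: names come from this PR; argument lists, return layouts,
# and the import path are assumptions, not verified API.
import torch
from torch.sparse import SparseSemiStructuredTensorCUTLASS  # import path assumed

x = torch.randn(128, 128, dtype=torch.float16, device="cuda")  # needs a 2:4-capable GPU
grad = torch.randn_like(x)

# Tile-based 2:4 pruning: packed values + metadata for X and X', plus the
# thread masks (a compressed bitmask) recording which elements were kept.
packed, meta, packed_t, meta_t, bitmask = torch._sparse_semi_structured_tile(x)

# Re-apply the same thread masks to another tensor (e.g. a gradient) to get its
# packed representation and metadata (tuple layout assumed) ...
packed_grad = torch._sparse_semi_structured_apply(grad, bitmask)

# ... or the equivalent dense tensor with the pruned positions zeroed out.
dense_grad = torch._sparse_semi_structured_apply_dense(grad, bitmask)

# The new classmethod on the subclasses wraps this into a
# SparseSemiStructuredTensor in one call.
x_sp = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(x)
```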

cc @albanD

Differential Revision: D56190801

@pytorch-bot

pytorch-bot bot commented Mar 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122350

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b3898ad with merge base f433517:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: sparse release notes category label Mar 20, 2024
jcaip added a commit that referenced this pull request Mar 20, 2024
@jcaip jcaip changed the title [sparse][be] Add fast semi-structured spasification kernels [sparse][semi-structured] Add fast semi-structured sparsification kernels Mar 20, 2024
@jcaip jcaip requested review from alexsamardzic and cpuhrsch March 20, 2024 22:42
#include <torch/library.h>
#include <torch/types.h>

TORCH_LIBRARY_FRAGMENT(sparse, m) {
Contributor

I don't think this is something we typically land in aten. We'd probably want to make these private native functions.

Contributor

@cpuhrsch cpuhrsch left a comment

Aside from the use of `TORCH_LIBRARY_IMPL`, I think this is fine.

@alexsamardzic
Collaborator

Will give it a more detailed look, but just a couple of questions for the moment: Is this PR also replacing what is in the torch.sparse._semi_structured_conversions module? If so, shall we remove that module (and do we have any benchmarking results vs. the compiled versions of the functions in that module)?

@jcaip
Contributor Author

jcaip commented Mar 21, 2024

@alexsamardzic

No, this shouldn't replace what's in torch.sparse._semi_structured_conversions. I think ideally we would want to express this logic in pure Python so it could be torch.compile'd like the functions there. However, these kernels are slightly different because I believe they apply 2:4 sparsity over a 4x4 tile instead of a 1x4 strip.

I didn't have time to look into this further, so I'm just landing these as-is for now and we can revisit as needed. I'd definitely be interested if you had ideas here, though.
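
For intuition on the 4x4-tile point, a toy illustration (my sketch, not code from this PR): pruning per 4x4 tile lets both X and X' satisfy the 2:4 constraint, which independent per-row 1x4 strips cannot guarantee.

```python
import torch

# A 4x4 mask where every row AND every column keeps exactly 2 of 4 elements,
# so both the pruned tile and its transpose are valid 2:4 patterns. Per-row
# 1x4 pruning could, e.g., keep columns 0 and 1 in every row, leaving columns
# with 4/4 or 0/4 nonzeros and breaking 2:4 for the transpose.
mask = torch.tensor([
    [1, 1, 0, 0],
    [0, 0, 1, 1],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
], dtype=torch.bool)

assert (mask.sum(dim=0) == 2).all() and (mask.sum(dim=1) == 2).all()
```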

@alexsamardzic
Collaborator

OK, I looked more carefully into it - so this is just about importing this functionality from xformers, right?

I've benchmarked it, for the CUTLASS back-end, against the compiled version of the existing conversion code (I'll denote the variant using the compiled existing code as "Triton", and the new variant as "C++/CUDA").

Benchmarking script
#! /usr/bin/env python

import random

import torch
from torch.testing import make_tensor
import torch.utils.benchmark as benchmark

from torch.sparse.semi_structured import (
    SparseSemiStructuredTensor,
    to_sparse_semi_structured,
)

SEMI_STRUCTURED_SUPPORTED_DTYPES = [
    torch.float16,
    torch.bfloat16,
    ###torch.float32,
    ###torch.int8,
]


# Generate a dense tensor whose zero pattern already satisfies 2:4 (or 1:2) structured sparsity.
def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
    pattern = "2by4" if dtype != torch.float32 else "1by2"
    if pattern == "1by2":
        ksparse = 2
        choices = [[0, 1], [1, 0]]
    elif pattern == "2by4":
        ksparse = 4
        choices = [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1],
        ]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)]
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask applied.
    dense = dense.masked_fill(~mask, 0)
    return dense


# Time compiled to_sparse_semi_structured ("Triton") vs. the new fast kernel ("C++/CUDA") for one (m, k) shape.
def run_benchmark(m, k, device, dtype, results):
    label = f"Triton vs. C++/CUDA ({dtype})"
    sub_label = f"m:{m:5d} | k:{k:5d}"

    x = rand_sparse_semi_structured(m, k, dtype, device)

    to_sparse_semi_structured_compiled = torch.compile(
        lambda x: to_sparse_semi_structured(x)
    )

    SparseSemiStructuredTensor._FORCE_CUTLASS = True
    # Warm up both paths once (triggers compilation) before timing.
    xsp_1 = to_sparse_semi_structured_compiled(x)
    xsp_2 = SparseSemiStructuredTensor.from_dense_fast(x)

    measurement = benchmark.Timer(
        stmt="to_sparse_semi_structured(x)",
        globals={
            "to_sparse_semi_structured": to_sparse_semi_structured_compiled,
            "x": x,
        },
        label=label,
        sub_label=sub_label,
        description="Triton",
    ).blocked_autorange()
    results.append(measurement)

    measurement = benchmark.Timer(
        stmt="to_sparse_semi_structured(x)",
        globals={
            "to_sparse_semi_structured": SparseSemiStructuredTensor.from_dense_fast,
            "x": x,
        },
        label=label,
        sub_label=sub_label,
        description="C++/CUDA",
    ).blocked_autorange()
    results.append(measurement)


if __name__ == "__main__":
    device = "cuda"

    shapes = [
        # distilbert shapes
        (768, 3072, 768),
        (3072, 768, 3072),
        # jiecao shapes
        (1024, 1536, 2048),
        (1024, 9408, 2048),
        (1024, 3200, 2048),
        (1024, 256, 9472),
        (1024, 10240, 256),
        (1024, 256, 12608),
        (1024, 2560, 1024),
        (1024, 512, 10240),
        (1024, 10240, 512),
        (1024, 2048, 1024),
        (1024, 512, 512),
        (1024, 1024, 1024),
        (1024, 2048, 2048),
        (2048, 1536, 2048),
        (2048, 9408, 2048),
        (2048, 3200, 2048),
        (2048, 256, 9472),
        (2048, 10240, 256),
        (2048, 256, 12608),
        (2048, 2560, 1024),
        (2048, 512, 10240),
        (2048, 10240, 512),
        (2048, 2048, 1024),
        (2048, 512, 512),
        (2048, 1024, 1024),
        (2048, 2048, 2048),
    ]

    torch.set_printoptions(
        precision=3,
        threshold=None,
        edgeitems=4,
        linewidth=460,
        profile=None,
        sci_mode=False,
    )

    for dtype in SEMI_STRUCTURED_SUPPORTED_DTYPES:
        results = []

        for m, k, _ in shapes:
            try:
                print(
                    f"m = {m:5d}, k = {k:5d}, dtype = {dtype} ... ",
                    end="",
                )
                run_benchmark(m, k, device, dtype, results)
                print("ok")
            except Exception:
                print("failed")
                continue

        compare = benchmark.Compare(results)
        compare.colorize(rowwise=True)
        compare.print()

"C++/CUDA" variant is indeed reported as significantly faster than "Triton" variant, still each one has its own caveats:

  • "C++/CUDA" variant supports float16 and bfloat16, while "Triton" variant supports int8 and float32 too.
  • "Triton" variant will reuse code compiled for the first shape encountered. So the real performance could be actually much better than shown by the script above (but still slower than "C++/CUDA" variant, for the same shape), in case if to_sparse_semi_structured re-compiled for each individual shape. That's something to keep in mind in case this C++/CUDA code indeed ported to to-be-compiled PyTorch code.
  • Alike, as mentioned in "to_sparse_semi_structured" cannot save memory on A100 using CUTLASS #115008, when "Triton" variant not compiled it uses excessive amount of GPU memory. This is because it is pretty much written to be compiled, so in order to avoid graph breaks it uses lots of auxiliary tensors. When compiled, it all gets fused, so it's OK, but if used as is then all of these auxiliary tensors are materialized, so memory usage is huge. On the other side, PyTorch has issues (or at least it had it when my conversion code merged) with including compiled code. So that's also something to keep in mind if porting this new C++/CUDA code to PyTorch.

As far as CUTLASS side of things concerned, my opinion is that it should provide fast conversion (at least from dense to sparse semi-structured) routine itself. It actually provides code for the most complicated part of the conversion, but it's CPU code, and implemented as an utility function (i.e. not directly exposed through CUTLASS API). Still, most of the functionality needed for this conversion is written as host/device functions, so it should be possible to write corresponding GPU code in a reasonable amount of time. If we think it would be worthwhile, I can do it at some point, and compare to this xformers code. On the other side, these issues with compiled code mentioned above would still hold.
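
As a hedged aside on the recompile-per-shape caveat above (not part of the original benchmark): one way to give each shape its own compiled artifact is to reset dynamo between shapes, roughly:

```python
import torch
import torch._dynamo

def convert_compiled_fresh(fn, x):
    # Drop previously compiled artifacts so this shape gets its own specialization.
    torch._dynamo.reset()
    return torch.compile(fn, dynamic=False)(x)
```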

@jcaip jcaip changed the title [sparse][semi-structured] Add fast semi-structured sparsification kernels [sparse][be] Add fast semi-structured spasification kernels Mar 27, 2024
jcaip added a commit that referenced this pull request Mar 27, 2024
jcaip added a commit that referenced this pull request Apr 17, 2024
jcaip added a commit that referenced this pull request Apr 18, 2024
@jcaip
Contributor Author

jcaip commented Apr 19, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
…rch#122350)"

This reverts commit c63a7b5.

Reverted pytorch#122350 on behalf of https://github.com/malfet due to This broke rocm builds, which is visible on PR as well ([comment](pytorch#122350 (comment)))
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
@github-actions github-actions bot deleted the gh/jcaip/64/head branch June 1, 2024 02:02
jcaip added a commit to pytorch/ao that referenced this pull request Jun 6, 2024
This PR adds support for training acceleration, using the runtime semi-structured (2:4) sparsity kernels that landed in core earlier: pytorch/pytorch#122350

This collects the necessary autograd functions to support training and packages them up in a replacement `nn.Linear` module, `SemiSparseLinear`, as well as a user API to swap out modules, `swap_linear_with_semi_sparse_linear_`.

It also adds in some benchmarking code from xformers in order to measure the speedup of this module when applied to DINO shapes. 

We have a blog post coming out with more details about how this works. 

Testing:
```
python test/sparsity/test_fast_sparse_training.py 
```

Benchmarking:
```
python benchmarks/benchmark_semi_sparse.py 
```

For VIT-L MLP shapes we see the following results:
```
[------------------------------------------------ mlpfwbw -------------------------------------------------]
                                  |   act24   |   dense   |   w24    |  s24_inp_sparsify24  |  s24_inp_clone
1 threads: -------------------------------------------------------------------------------------------------
      f16 (44160,1024,4096,1024)  |  11881.0  |  11534.3  |  9204.7  |        255.1         |      125.8

Times are in microseconds (us).
```
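
For context, a hedged sketch of how the swap API described above might be used; the import path, the exact signature of `swap_linear_with_semi_sparse_linear_`, and the config format are assumptions for illustration, not the verified torchao API.

```python
import torch
import torch.nn as nn
# Import path, signature, and config format below are assumptions.
from torchao.sparsity.training import (
    SemiSparseLinear,
    swap_linear_with_semi_sparse_linear_,
)

model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.GELU(),
    nn.Linear(4096, 1024),
).half().cuda()

# Map fully-qualified module names to the replacement class (format assumed).
swap_linear_with_semi_sparse_linear_(model, {"0": SemiSparseLinear, "2": SemiSparseLinear})

# Forward/backward now go through the runtime 2:4 sparsification kernels from this PR.
out = model(torch.randn(4096, 1024, dtype=torch.float16, device="cuda"))
out.sum().backward()
```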