[sparse] semi-structured sparse + torch.compile support by jcaip · Pull Request #111049 · pytorch/pytorch · GitHub

Conversation

jcaip
Contributor

@jcaip jcaip commented Oct 11, 2023

Stack from ghstack (oldest at bottom):

Summary:

This PR adds torch.compile support for semi-structured sparsity,
using the subclass tracing @bdhirsh added.

Based on whether we are using cuSPARSELt or CUTLASS, we return a
different representation of the inner tensors.
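For context, "semi-structured" here refers to the 2:4 sparsity pattern: at most 2 nonzeros in every group of 4 elements, stored as packed values plus index metadata. A minimal pure-Python sketch of the idea (illustration only; the actual cuSPARSELt/CUTLASS inner tensors use packed device formats, and these helper names are hypothetical):

```python
# Illustrative sketch of 2:4 ("semi-structured") compression: for every
# group of 4 values, keep the 2 largest-magnitude entries plus their
# in-group indices. The real backends pack values and metadata into
# dedicated device tensors; this only mirrors the idea.

def compress_2_4(row):
    assert len(row) % 4 == 0
    values, indices = [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        # indices of the 2 largest-magnitude entries, kept in order
        keep = sorted(sorted(range(4), key=lambda i: -abs(group[i]))[:2])
        values.extend(group[i] for i in keep)
        indices.extend(keep)
    return values, indices

def decompress_2_4(values, indices, n):
    out = [0.0] * n
    for k, idx in enumerate(indices):
        out[(k // 2) * 4 + idx] = values[k]
    return out

row = [0.0, 3.0, -1.0, 0.0, 5.0, 0.0, 0.0, 2.0]
vals, meta = compress_2_4(row)
restored = decompress_2_4(vals, meta, len(row))
```

Round-tripping a row that is already 2:4 sparse recovers it exactly, which is the invariant the compressed representation relies on.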

Test Plan:

```
python test/test_sparse_semi_structured.py -k compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

@pytorch-bot

pytorch-bot bot commented Oct 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111049

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 575d075 with merge base 4b324a8:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: sparse release notes category label Oct 11, 2023
jcaip added a commit that referenced this pull request Oct 11, 2023
ghstack-source-id: 56dd0a8
Pull Request resolved: #111049
jcaip added a commit that referenced this pull request Oct 11, 2023
ghstack-source-id: 999ce27
Pull Request resolved: #111049
@jcaip
Contributor Author

jcaip commented Oct 12, 2023

So I think we will need some additional work before subclass tracing and 2:4 sparsity can work together. Loosely, I think the problem is that the op level at which inductor breaks things down before calling into the subclass is more decomposed than what we see in eager (which is what we currently support).

For example, fp16 addmm in eager is a simple swap with torch._cslt_sparse_mm, but in the compiled version it looks like addmm is being broken down into smaller pieces, including casts to fp32 for the computation type. See here for a trace: https://www.internalfb.com/phabricator/paste/view/P851580855

It looks like when compiling, inductor tries different combinations of ops to represent aten.linear:

- transpose -> addmm
- permute(0, 1) -> addmm
- reinterpret_tensor (flip strides) -> addmm.out (instead of addmm.default)

When I try linear -> contiguous -> relu, I also see it trying to run mm and mm.out.

So in order to support this, we'll either need to:

  1. Restrict the decompositions/codegen somehow? I'm not sure exactly what the issue is here, so I will need to dig deeper. It's not clear to me what part of the codebase is trying different combinations.
  2. Add op support for all of the possible generated ops. This is probably something we should do anyway, as some of these combinations may be faster. For example, adding support for passing in the output matrix to the sparse matmul.

permute(0,1) and reinterpret_tensor should be mapped to transpose.

fp32 casts can be no-ops.
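Taken together, the two notes above amount to a small normalization table over the decomposed ops. A torch-free sketch of that idea (op names, the `transposed` flag, and the handlers are illustrative stand-ins, not the actual SparseSemiStructuredTensor dispatch code):

```python
# Hypothetical sketch: normalize the ops inductor emits back down to
# the few operations the sparse subclass actually supports. All names
# here are illustrative.

def handle_transpose(tensor_state):
    # flip a "transposed" flag instead of materializing a transpose
    return {**tensor_state, "transposed": not tensor_state["transposed"]}

def handle_noop(tensor_state):
    # e.g. a cast to fp32 for the computation type: treat as a no-op,
    # since the sparse kernel fixes its own compute type
    return tensor_state

DISPATCH = {
    "aten.t": handle_transpose,
    "aten.permute_0_1": handle_transpose,          # permute(0, 1) == transpose
    "prims.reinterpret_flip_strides": handle_transpose,
    "prims.convert_element_type": handle_noop,     # fp32 cast -> no-op
}

def dispatch(op_name, tensor_state):
    if op_name not in DISPATCH:
        raise NotImplementedError(f"no sparse override for {op_name}")
    return DISPATCH[op_name](tensor_state)

state = {"transposed": False}
state = dispatch("aten.permute_0_1", state)            # behaves like transpose
state = dispatch("prims.convert_element_type", state)  # dropped
```

In the real subclass this logic would live behind __torch_dispatch__, with unsupported ops raising rather than silently falling through.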

Summary:

Placeholder PR for subclassing + 2:4 sparsity

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
jcaip added a commit that referenced this pull request Oct 12, 2023
ghstack-source-id: 743b278
Pull Request resolved: #111049
transposed=not args[0].transposed,
)

if func is torch.ops.prims.convert_element_type.default:
Contributor Author


@bdhirsh These are the ops I added to get this test to "work"; you can comment them out to see where it fails earlier in the chain.

jcaip added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: 7183856
Pull Request resolved: #111049
@jcaip jcaip changed the title [wip] semi-structured sparse + torch.compile support [sparse] semi-structured sparse + torch.compile support Oct 13, 2023
jcaip added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: eeba936
Pull Request resolved: #111049
@jcaip jcaip reopened this Oct 18, 2023
@jcaip
Contributor Author

jcaip commented Oct 18, 2023

This is kind of strange.

I was able to reproduce this error running locally, but I don't think it's related to my changes here.

Running the following command yields failures for the torch.compile tests; note that the other cuSPARSELt/CUTLASS tests do not report a memory leak.

```
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py
```

............................................................................................
======================================================================
ERROR: test_conversions_all_patterns_backend_cutlass_cuda_float16 (__main__.TestSparseSemiStructuredCUDA)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2452, in wrapper
    with policy():
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 1902, in __exit__
    raise RuntimeError(msg)
RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_conversions_all_patterns_backend_cutlass_cuda_float16! Caching allocator allocated memory was 1536 and is now reported as 2560 on device 0. CUDA driver allocated memory was 1044054016 and is now 1046151168.

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py -k test_conversions_all_patterns_backend_cutlass_cuda_float16

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(1, 128)_cuda (__main__.TestSparseSemiStructuredCUDA)
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2452, in wrapper
    with policy():
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 1902, in __exit__
    raise RuntimeError(msg)
RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(1, 128)_cuda! Caching allocator allocated memory was 2048 and is now reported as 53760 on device 0. CUDA driver allocated memory was 1438318592 and is now 1440415744.

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py -k test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(1, 128)_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 168 tests in 110.646s

FAILED (errors=2)

Then I tried commenting out the call to to_sparse_semi_structured, which means we are just comparing torch.compile vs. eager mode. But I still see the compile tests fail; strangely, I see even more failures:

............................................................................................
======================================================================
ERROR: test_conversions_all_patterns_backend_cutlass_cuda_float16 (__main__.TestSparseSemiStructuredCUDA)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2452, in wrapper
    with policy():
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 1902, in __exit__
    raise RuntimeError(msg)
RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_conversions_all_patterns_backend_cutlass_cuda_float16! Caching allocator allocated memory was 1536 and is now reported as 2560 on device 0. CUDA driver allocated memory was 1044054016 and is now 1046151168.

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py -k test_conversions_all_patterns_backend_cutlass_cuda_float16

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(1, 128)_cuda (__main__.TestSparseSemiStructuredCUDA)
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2452, in wrapper
    with policy():
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 1902, in __exit__
    raise RuntimeError(msg)
RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(1, 128)_cuda! Caching allocator allocated memory was 2048 and is now reported as 35328 on device 0. CUDA driver allocated memory was 1438318592 and is now 1440415744.

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py -k test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(1, 128)_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(128, 128)_cuda (__main__.TestSparseSemiStructuredCUDA)
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2452, in wrapper
    with policy():
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 1902, in __exit__
    raise RuntimeError(msg)
RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(128, 128)_cuda! Caching allocator allocated memory was 35328 and is now reported as 68608 on device 0. CUDA driver allocated memory was 1440415744 and is now 1442512896.

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py -k test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(128, 128)_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(64, 128)_cuda (__main__.TestSparseSemiStructuredCUDA)
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2452, in wrapper
    with policy():
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 1902, in __exit__
    raise RuntimeError(msg)
RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(64, 128)_cuda! Caching allocator allocated memory was 68608 and is now reported as 101888 on device 0. CUDA driver allocated memory was 1442512896 and is now 1444610048.
To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py -k test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(64, 128)_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(64, 128, 128)_cuda (__main__.TestSparseSemiStructuredCUDA)
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2452, in wrapper
    with policy():
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 1902, in __exit__
    raise RuntimeError(msg)
RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(64, 128, 128)_cuda! Caching allocator allocated memory was 101888 and is now reported as 135168 on device 0. CUDA driver allocated memory was 1444610048 and is now 1446707200.

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py -k test_mlp_contiguous_relu_compile_backend_cusparselt_dense_input_shape_(64, 128, 128)_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_mlp_contiguous_relu_compile_backend_cutlass_dense_input_shape_(128, 128)_cuda (__main__.TestSparseSemiStructuredCUDA)
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2452, in wrapper
    with policy():
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 1902, in __exit__
    raise RuntimeError(msg)
RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_mlp_contiguous_relu_compile_backend_cutlass_dense_input_shape_(128, 128)_cuda! Caching allocator allocated memory was 167424 and is now reported as 200704 on device 0. CUDA driver allocated memory was 1446707200 and is now 1448804352.

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py -k test_mlp_contiguous_relu_compile_backend_cutlass_dense_input_shape_(128, 128)_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_mlp_contiguous_relu_compile_backend_cutlass_dense_input_shape_(64, 128)_cuda (__main__.TestSparseSemiStructuredCUDA)
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2453, in wrapper
    method(*args, **kwargs)
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 2452, in wrapper
    with policy():
  File "/home/jessecai/local/A/pytorch/torch/testing/_internal/common_utils.py", line 1902, in __exit__
    raise RuntimeError(msg)
RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_mlp_contiguous_relu_compile_backend_cutlass_dense_input_shape_(64, 128)_cuda! Caching allocator allocated memory was 200704 and is now reported as 233984 on device 0. CUDA driver allocated memory was 1448804352 and is now 1450901504.

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/test_sparse_semi_structured.py -k test_mlp_contiguous_relu_compile_backend_cutlass_dense_input_shape_(64, 128)_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 168 tests in 119.071s

FAILED (errors=7)

cc @clee2000 @bdhirsh do y'all have any ideas what might be causing this? Could it be something funky with the tooling / torch.compile?

@bdhirsh
Contributor

bdhirsh commented Oct 20, 2023

Hey @jcaip - I messed around with the test locally, and this seems like a memory leak directly in cutlass / the SparseSemiStructuredTensor subclass. Here's a minimal repro:

```
import torch
import random
from torch.sparse.semi_structured import SparseSemiStructuredTensor, to_sparse_semi_structured

SparseSemiStructuredTensor._FORCE_CUTLASS = True

def f():
    mask_entries = [random.choice([[0, 1], [1, 0]]) for i in range(16384)]
    A = torch.tensor(mask_entries, dtype=torch.float16, device='cuda').reshape(128, 256).contiguous()
    A_sparse = to_sparse_semi_structured(A)

print(torch.cuda.memory_allocated())
f()
print(torch.cuda.memory_allocated())
```

prints:

```
0
512
```

@cpuhrsch
Contributor

@alexsamardzic - Alek can you take a look into this as well?

@alexsamardzic
Collaborator

This is a sort-of-known problem: the code doing the conversion is itself @torch.compile-d PyTorch code, and the memory leak seemingly appears because of this. Namely, the leak disappears if, in torch/sparse/_semi_structured_conversions.py, the line:

```
def sparse_semi_structured_from_dense_cutlass(dense, compile=True):
```

is changed to:

```
def sparse_semi_structured_from_dense_cutlass(dense, compile=False):
```

While trying to change the function in question to come up with a minimal self-contained example, I came up with something seemingly unrelated:

```
import torch

@torch.compile
def foo(A):
    return A

def f():
    A = torch.ones(128, dtype=torch.float16, device='cuda')
    B = foo(A)

print(torch.cuda.memory_allocated())
f()
print(torch.cuda.memory_allocated())
```

It will also print 512 instead of 0.

@lezcano @peterbell10

@jcaip
Contributor Author

jcaip commented Oct 23, 2023

@alexsamardzic What do you think about setting this flag to False, now that we have torch.compile support for subclasses? I think it can also cause issues if we call torch.compile again later.

@alexsamardzic
Collaborator

That's fine with me.

@jcaip
Contributor Author

jcaip commented Oct 23, 2023

Ah I believe the test is failing because I was not using torch._dynamo.test_case.TestCase for my tests.

Writing a new test for semi-structured sparse should fix this.
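For readers unfamiliar with why the base class matters: torch._dynamo.test_case.TestCase resets compiler state around each test, so cached compile artifacts from one test don't register as leaked memory in the next. A torch-free sketch of the same isolation pattern with plain unittest (the cache and reset function here are hypothetical stand-ins):

```python
# Illustrative sketch of per-test state isolation, the pattern that
# torch._dynamo.test_case.TestCase provides for dynamo's caches.
import unittest

_global_cache = {}  # stands in for a compiler's global caches

def reset_state():
    # analogous to resetting compiler state between tests
    _global_cache.clear()

class IsolatedTestCase(unittest.TestCase):
    def setUp(self):
        reset_state()

    def tearDown(self):
        reset_state()

class Example(IsolatedTestCase):
    def test_populates_cache(self):
        _global_cache["artifact"] = object()
        self.assertIn("artifact", _global_cache)

    def test_sees_clean_state(self):
        # without the reset, the previous test's artifact would linger
        # and could be misread as a leak
        self.assertEqual(_global_cache, {})
```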

jcaip added a commit that referenced this pull request Oct 23, 2023

ghstack-source-id: 8e127cc
Pull Request resolved: #111049
@jcaip
Contributor Author

jcaip commented Oct 24, 2023

@pytorchbot merge -f "passing ciflow/slow now and unrelated failing test"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@facebook-github-bot facebook-github-bot deleted the gh/jcaip/46/head branch October 27, 2023 14:25
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Pull Request resolved: pytorch#111049
Approved by: https://github.com/cpuhrsch
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Pull Request resolved: pytorch#111049
Approved by: https://github.com/cpuhrsch

Labels

ciflow/slow ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: sparse release notes category

7 participants