Add warning about removed sm50 and sm60 arches by atalman · Pull Request #158301 · pytorch/pytorch · GitHub

Conversation

@atalman
Contributor

@atalman atalman commented Jul 15, 2025

Related to #157517

Detect when users are running a torch build compiled with CUDA 12.8/12.9 on Maxwell or Pascal architectures. The warning references issue #157517 and asks people to install CUDA 12.6 builds if they are running on sm50 or sm60 architectures.

Test:

```
>>> torch.cuda.get_arch_list()
['sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120', 'compute_120']
>>> torch.cuda.init()
/home/atalman/.conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:263: UserWarning:
    Found <GPU Name> which is of cuda capability 5.0.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability supported by this library is 7.0.

  warnings.warn(
/home/atalman/.conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:268: UserWarning:
                        Support for Maxwell and Pascal architectures is removed for CUDA 12.8+ builds.
                        Please see https://github.com/pytorch/pytorch/issues/157517
                        Please install CUDA 12.6 builds if you require Maxwell or Pascal support.
```
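The shape of the check behind these warnings can be sketched without torch: derive the lowest supported compute capability from the `sm_XY` entries of the compiled arch list and compare it against the device's capability. This is an illustrative sketch, not the actual `torch/cuda/__init__.py` code; `min_supported_capability` is a hypothetical helper name.

```python
def min_supported_capability(arch_list):
    """Return the lowest (major, minor) capability among 'sm_XY' entries.

    'sm_70' -> (7, 0); 'sm_120' -> (12, 0); 'compute_*' (PTX) entries
    are ignored here because they describe JIT targets, not SASS.
    """
    caps = []
    for arch in arch_list:
        if arch.startswith("sm_"):
            num = arch.split("_")[1]
            caps.append((int(num[:-1]), int(num[-1])))
    return min(caps)


arch_list = ['sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90',
             'sm_100', 'sm_120', 'compute_120']
device_cap = (5, 0)  # e.g. a Maxwell GPU reporting capability 5.0

min_cap = min_supported_capability(arch_list)
if device_cap < min_cap:  # tuple comparison: (5, 0) < (7, 0)
    print(f"capability {device_cap} below minimum {min_cap}: warn the user")
```

For the arch list in the test output above, the derived minimum is (7, 0), matching the "minimum cuda capability supported by this library is 7.0" line in the warning.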

cc @ptrblck @msaroufim @eqy @jerryzh168 @albanD @malfet

@pytorch-bot

pytorch-bot bot commented Jul 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158301

Note: Links to docs will display an error until the docs builds have been completed.

❌ 61 Cancelled Jobs, 1 Unrelated Failure

As of commit 100002a with merge base 0879921:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@atalman atalman added topic: not user facing topic category module: cuda Related to torch.cuda, and CUDA support in general labels Jul 15, 2025
@atalman atalman requested review from eqy and syed-ahmed as code owners July 15, 2025 13:43
```python
if current_arch < min_arch:
    warnings.warn(
        old_gpu_warn
```
Contributor Author

Looks like incorrect_binary_warn is never used. However, it's probably the more accurate warning.

Collaborator

@nWEIdia nWEIdia left a comment


LGTM. Just had one small suggestion.

@atalman atalman requested review from albanD and malfet July 16, 2025 10:36
@albanD
Collaborator

albanD commented Jul 16, 2025

Wait, these warnings are saying two opposite things lol.
One says the GPU is not supported at all, and the other says you just need to install a different binary.

Can we rationalize these messages to be more aligned with the state of the world:

  • What is the real lowest version supported by ANY binary (this needs to be hardcoded)? Warn based on that.
  • What is the lowest/newest version supported by THIS binary? Warn and suggest an appropriate binary, both down (for an old arch) and up (for a newer arch).
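The two-tier scheme suggested above can be sketched as follows. This is illustrative only, not PyTorch source: `LOWEST_ANY_BINARY` and `warnings_for` are hypothetical names, and the hardcoded floor of (5, 0) is an assumed placeholder.

```python
# Hypothetical hardcoded floor: the lowest capability ANY published
# binary supports (assumption for illustration).
LOWEST_ANY_BINARY = (5, 0)


def warnings_for(device_cap, binary_caps):
    """Return warning messages for a device capability, given the
    (major, minor) capabilities compiled into THIS binary."""
    msgs = []
    if device_cap < LOWEST_ANY_BINARY:
        # No binary anywhere supports this GPU.
        msgs.append("GPU too old for any PyTorch binary")
    elif device_cap < min(binary_caps):
        # Too old for this build, but an older-CUDA build may work.
        msgs.append("GPU too old for this binary; install an older-CUDA build")
    elif device_cap > max(binary_caps):
        # Newer than this build targets; suggest a newer-CUDA build.
        msgs.append("GPU newer than this binary; install a newer-CUDA build")
    return msgs


binary_caps = [(7, 0), (7, 5), (8, 0), (9, 0)]
print(warnings_for((5, 0), binary_caps))
```

With this split, the "too old for any binary" and "just install a different binary" cases can no longer contradict each other for the same device.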

Collaborator

@albanD albanD left a comment


Looks great, thanks!

@atalman
Contributor Author

atalman commented Jul 16, 2025

@pytorchmergebot merge -f "lint is green"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@nWEIdia
Collaborator

nWEIdia commented Jul 16, 2025

Just adding a note that in the future we might want to re-evaluate the "cur_arch > max_arch" case, as there could be scenarios the binary may still support. But "cur_arch < min_arch" is definitely not supported.

E.g., suppose we build sm up to sm80; running on sm86 would still work. The same reasoning applies to sm120.
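The point above can be sketched as a rough compatibility predicate. This is an illustrative sketch of the idea, not PyTorch source; `can_likely_run` is a hypothetical name, and it relies on two CUDA properties: embedded PTX (`compute_XY`) can be JIT-compiled forward for newer devices, and SASS is minor-version compatible within a major (e.g. sm_86 can run code built for sm_80).

```python
def can_likely_run(device_cap, arch_list):
    """Rough check: can a device of (major, minor) capability run a
    binary with this arch list ('sm_XY' = SASS, 'compute_XY' = PTX)?"""
    sm_caps, ptx_caps = [], []
    for arch in arch_list:
        kind, num = arch.split("_")
        cap = (int(num[:-1]), int(num[-1]))
        (sm_caps if kind == "sm" else ptx_caps).append(cap)
    if device_cap in sm_caps:
        return True  # exact SASS match
    if any(device_cap >= p for p in ptx_caps):
        return True  # newer device can JIT-compile the embedded PTX
    # SASS minor-version compatibility within the same major:
    # e.g. (8, 6) can run code compiled for (8, 0)
    return any(s[0] == device_cap[0] and s <= device_cap for s in sm_caps)
```

So a binary built only up to sm80 (with compute_80 PTX) would still pass for an sm86 or newer device, while a capability below min_arch fails outright, which is why "cur_arch > max_arch" deserves softer handling than "cur_arch < min_arch".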

@atalman
Contributor Author

atalman commented Jul 16, 2025

@pytorchbot cherry-pick --onto release/2.8 -c critical

pytorchbot pushed a commit that referenced this pull request Jul 16, 2025

Pull Request resolved: #158301
Approved by: https://github.com/nWEIdia, https://github.com/albanD

(cherry picked from commit fb731fe)
@pytorchbot
Collaborator

Cherry picking #158301

The cherry pick PR is at #158478 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated:


atalman added a commit that referenced this pull request Jul 16, 2025
Add warning about removed sm50 and sm60 arches (#158301)


Pull Request resolved: #158301
Approved by: https://github.com/nWEIdia, https://github.com/albanD

(cherry picked from commit fb731fe)

Co-authored-by: atalman <atalman@fb.com>
@facebook-github-bot
Contributor

@pytorchbot revert -m="Diff reverted internally" -c="ghfirst"

This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert, and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Jul 19, 2025
This reverts commit fb731fe.

Reverted #158301 on behalf of https://github.com/facebook-github-bot due to "Diff reverted internally"
@pytorchmergebot
Collaborator

@atalman your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Jul 19, 2025
@atalman atalman force-pushed the add_warning_about_old_sm branch from e17ea11 to 7706737 Compare July 19, 2025 00:39
Move code

fixes

Revert "conda"

This reverts commit 2853662.

Revert "use tos accept"

This reverts commit 8b34264.

Revert "conda"

This reverts commit 2853662.

Revert "Revert "conda""

This reverts commit e732654.

Revert "Revert "use tos accept""

This reverts commit c456c54.

Revert "Revert "conda""

This reverts commit bb4fa09.

fix

fix_arch_list

fix

fixes
@atalman atalman force-pushed the add_warning_about_old_sm branch from e8cd442 to 100002a Compare July 19, 2025 00:49
```python
if torch.version.cuda is not None:  # on ROCm we don't want this check
    CUDA_VERSION = torch._C._cuda_getCompiledVersion()  # noqa: F841
if (
    torch.version.cuda is not None and torch.cuda.get_arch_list()
```
Contributor Author

The new version adds a check for torch.cuda.get_arch_list().
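The extra guard can be sketched as a small predicate. This is illustrative, not the exact `torch/cuda/__init__.py` code; `should_check_arch` is a hypothetical name. The point is that the warning logic should be skipped both on ROCm (where `torch.version.cuda` is None) and on builds that report an empty arch list, which is what broke internal users in the first landing.

```python
def should_check_arch(cuda_version, arch_list):
    """Only run the old-arch warning when this is a CUDA build AND the
    binary actually reports a non-empty compiled arch list."""
    return cuda_version is not None and bool(arch_list)


print(should_check_arch("12.8", ["sm_70"]))  # CUDA build with arches: check
print(should_check_arch("12.8", []))         # empty arch list: skip
print(should_check_arch(None, ["sm_70"]))    # ROCm build: skip
```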

@atalman atalman closed this Jul 19, 2025
pytorchmergebot pushed a commit that referenced this pull request Jul 20, 2025

Please note I reverted the original PR #158301 because it broke internal users. This is a reland, with an added check for a non-empty torch.cuda.get_arch_list().
Pull Request resolved: #158700
Approved by: https://github.com/huydhn, https://github.com/Skylion007, https://github.com/eqy
tvukovic-amd pushed a commit to ROCm/pytorch that referenced this pull request Aug 20, 2025
Add warning about removed sm50 and sm60 arches (pytorch#158301)


Pull Request resolved: pytorch#158301
Approved by: https://github.com/nWEIdia, https://github.com/albanD

(cherry picked from commit fb731fe)

Co-authored-by: atalman <atalman@fb.com>