Fix slice scatter dtype consistency by ghostspiders · Pull Request #160851 · pytorch/pytorch · GitHub

Conversation

@ghostspiders
Contributor

@ghostspiders ghostspiders commented Aug 18, 2025

Fixes #147842
Fix torch.slice_scatter type inconsistency issue. I noticed previous PRs on this have stalled, so I'm opening this new PR.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot

pytorch-bot bot commented Aug 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160851

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 69bdfd0 with merge base 6737e2c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ghostspiders
Contributor Author

@pytorchbot label "module: inductor"

@pytorch-bot pytorch-bot bot added the topic: not user facing label Aug 18, 2025
@ghostspiders
Contributor Author

@eellison Thank you for your review. I have fixed the CI errors. Please let me know if any further modifications are needed, and kindly help run the CI checks again.

@soulitzer soulitzer added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 20, 2025
@ghostspiders
Contributor Author

@soulitzer Thank you for the review. I have gained a new understanding of this part. In Eager mode, the C++ implementation is directly invoked, while in Inductor mode, Python lowering performs type checks. I have now removed the assertion checks for Python lowering and fully delegated type handling to the C++ implementation to maintain logical consistency. The tests I conducted passed successfully.
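
For context, a minimal sketch of the Inductor lowering being discussed (the same snippet reviewed later in this thread, from torch/_inductor/lowering.py); the assert is the Python-side dtype check that eager's C++ kernel does not perform, and the trailing ellipsis stands for the unchanged body:

@register_lowering(aten.slice_scatter, type_promotion_kind=None)
def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
    # Fails under torch.compile when src's dtype differs from x's,
    # even though eager accepts the mismatch and keeps x's dtype.
    assert x.get_dtype() == src.get_dtype()
    ...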

@ghostspiders ghostspiders requested a review from soulitzer August 21, 2025 17:44
@soulitzer
Contributor

Thanks for the update, that sounds good. Could you add a test?

@ghostspiders
Contributor Author

@soulitzer Removing the assert x.get_dtype() == src.get_dtype() is not a good idea, so the solution has been reverted to the original approach: type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH. We have added tests, which have already passed.
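
For reference, a minimal sketch of that reverted registration, assuming the usual names in torch/_inductor/lowering.py; with a type_promotion_kind set, register_lowering promotes both inputs to a common dtype before the lowering body runs, so the assertion holds again:

@register_lowering(
    aten.slice_scatter,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
    # Both inputs now share the promoted dtype, but that promoted dtype
    # is not necessarily the dtype eager would produce for the output.
    assert x.get_dtype() == src.get_dtype()
    ...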

@ghostspiders
Contributor Author

ghostspiders commented Aug 22, 2025

@soulitzer Testing type promotion:
import torch

def test_slice_scatter_dtype_behavior():
    """Test slice_scatter in both Eager and Inductor modes"""

    print("=== Testing slice_scatter ===\n")

    # Test case
    target = torch.zeros(4, dtype=torch.float32)
    source = torch.tensor([1, 2], dtype=torch.int64)

    print(f"Target: {target.dtype}, Source: {source.dtype}")

    # Eager mode
    result_eager = target.slice_scatter(source, start=0, end=2)
    print(f"Eager: {result_eager.dtype}")

    # Inductor mode
    try:
        compiled_fn = torch.compile(lambda t, s: t.slice_scatter(s, start=0, end=2))
        result_inductor = compiled_fn(target, source)
        print(f"Inductor: {result_inductor.dtype}")

        # Compare
        dtype_match = result_eager.dtype == result_inductor.dtype
        value_match = torch.allclose(result_eager, result_inductor)

        print(f"Match: {'✅' if dtype_match and value_match else '❌'}")

    except Exception as e:
        print(f"Inductor failed: {e}")

if __name__ == "__main__":
    test_slice_scatter_dtype_behavior()


=== Testing slice_scatter ===

Target: torch.float32, Source: torch.int64
Eager: torch.float32
Inductor: torch.float32
Match: ✅

@soulitzer
Contributor

Removing the assert x.get_dtype() == src.get_dtype() is not a good idea

Why not? We don't have that assumption in eager?

ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH

Hmm, but my earlier point was that we don't actually do type promotion in this op.

@ghostspiders
Contributor Author

ghostspiders commented Aug 22, 2025

@soulitzer
Removing the assert x.get_dtype() == src.get_dtype() makes the test cases fail.

@ghostspiders
Contributor Author

@soulitzer The ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH test passes. I suspect it actually performs non-trivial work beyond simple type promotion, although the detailed mechanics remain unclear at this stage.

@soulitzer
Contributor

I suspect its not doing type promotion based on this example:

>>> a = torch.zeros(8, 8, dtype=torch.float32)
>>> b = torch.ones(2, 8, dtype=torch.float64)
>>> a.slice_scatter(b, start=6).dtype
torch.float32
>>>

Are you testing that?

@ghostspiders
Contributor Author

I suspect its not doing type promotion based on this example:

>>> a = torch.zeros(8, 8, dtype=torch.float32)
>>> b = torch.ones(2, 8, dtype=torch.float64)
>>> a.slice_scatter(b, start=6).dtype
torch.float32
>>>

Are you testing that?

YES
(mtorch) gaoyufeng@DESKTOP-QT4IL9F:~/code/pytorch$ python
Python 3.12.11 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:09:17) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.

>>> import torch
>>> a = torch.zeros(8, 8, dtype=torch.float32)
>>> b = torch.ones(2, 8, dtype=torch.float64)
>>> a.slice_scatter(b, start=6).dtype
torch.float32

@soulitzer
Contributor

soulitzer commented Aug 22, 2025

YES
(mtorch) gaoyufeng@DESKTOP-QT4IL9F:~/code/pytorch$ python
Python 3.12.11 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:09:17) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.

>>> import torch
>>> a = torch.zeros(8, 8, dtype=torch.float32)
>>> b = torch.ones(2, 8, dtype=torch.float64)
>>> a.slice_scatter(b, start=6).dtype
torch.float32

Sorry to clarify, what I mean to ask is if you have checked parity for this example between the inductor lowering and the op? I'm asking because I don't see it in the test case "test_slice_scatter_dtype_consistency", could you add it there?
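
For illustration, a minimal sketch of the parity check being asked for, written as a plain standalone function; the actual test added in the PR may use the Inductor test harness instead:

import torch

def test_slice_scatter_dtype_consistency():
    a = torch.zeros(8, 8, dtype=torch.float32)
    b = torch.ones(2, 8, dtype=torch.float64)

    def fn(a, b):
        return a.slice_scatter(b, start=6)

    eager = fn(a, b)
    compiled = torch.compile(fn, backend="inductor")(a, b)

    # Eager keeps the destination dtype (float32); the lowering should match.
    assert eager.dtype == compiled.dtype == torch.float32
    assert torch.equal(eager, compiled)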

@ghostspiders
Contributor Author

@soulitzer Thank you for your review. I apologize, as my previous test was incorrect. Removing the assertion x.get_dtype() == src.get_dtype() was the right decision. I have updated the test code accordingly. Thanks again for the review.

return torch.slice_scatter(y, x, 0)

for dtype in [
torch.int8,
Contributor

nit, but probably no need to test every single int type
one int, float64, and then complex is probably sufficient.

Contributor Author

Modified the test

@soulitzer
Contributor

Thanks for the update. For the test failures, let's just skip them in test/inductor/test_torchinductor_codegen_dynamic_shapes.py.
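
For reference, a minimal sketch of how that file usually marks a test as skipped through its test_failures dict; the entry name and device tuple below are assumptions, not the PR's actual change:

# in test/inductor/test_torchinductor_codegen_dynamic_shapes.py (hypothetical entry)
test_failures = {
    # ...
    "test_slice_scatter_dtype_consistency_dynamic_shapes": TestFailure(
        ("cpu", "cuda"), is_skip=True
    ),
    # ...
}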

@ghostspiders
Contributor Author

@soulitzer Thanks for the review. Can we merge it?

@ghostspiders ghostspiders requested a review from soulitzer August 25, 2025 15:59
@soulitzer
Contributor

We may have to fix a test failure to pass CI. See my previous comment here #160851 (comment)

@pytorch-bot

pytorch-bot bot commented Aug 27, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'marge' (choose from 'merge', 'revert', 'rebase', 'label', 'drci', 'cherry-pick')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick} ...

Try @pytorchbot --help for more info.


@register_lowering(aten.slice_scatter, type_promotion_kind=None)
def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
assert x.get_dtype() == src.get_dtype()
Contributor

I think we should replace the assertion with:

src = to_dtype(src, x.get_dtype())
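
For illustration, a minimal sketch of the lowering after applying that suggestion, abridged from the snippet above; the ellipsis stands for the unchanged body:

@register_lowering(aten.slice_scatter, type_promotion_kind=None)
def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
    # Cast src to the destination dtype instead of asserting the dtypes
    # already match, mirroring eager's behavior of keeping x's dtype.
    src = to_dtype(src, x.get_dtype())
    ...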

Contributor Author

All set, I've updated it.

@ghostspiders ghostspiders requested a review from soulitzer August 29, 2025 15:22
@soulitzer soulitzer left a comment

Thanks!

@soulitzer
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Successfully rebased fix-slice-scatter-dtype-consistency onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix-slice-scatter-dtype-consistency && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the fix-slice-scatter-dtype-consistency branch from da3ee2e to 69bdfd0 on September 1, 2025 22:20
@pytorch-bot pytorch-bot bot removed the ciflow/trunk label Sep 1, 2025
@soulitzer
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Sep 1, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Fixes pytorch#147842
Fix torch.slice_scatter type inconsistency issue. I noticed previous PRs on this have stalled, so I'm opening this new PR.

Pull Request resolved: pytorch#160851
Approved by: https://github.com/soulitzer
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
Fixes pytorch#147842
Fix torch.slice_scatter type inconsistency issue. I noticed previous PRs on this have stalled, so I'm opening this new PR.

Pull Request resolved: pytorch#160851
Approved by: https://github.com/soulitzer

Labels

ciflow/trunk · Merged · module: inductor · open source · topic: not user facing · triaged

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[inductor] torch.slice_scatter throws AssertionError when meeting internal float32

5 participants