[fx][pass] Support converting a float32 tensor to a scalar in FX trace. by thenumberouscode · Pull Request #158216 · pytorch/pytorch · GitHub

Conversation

@thenumberouscode
Contributor

@thenumberouscode thenumberouscode commented Jul 14, 2025

@pytorch-bot

pytorch-bot bot commented Jul 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158216

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 1 Unrelated Failure

As of commit 4d89735 with merge base f636736:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: dynamo and release notes: fx (release notes category) labels Jul 14, 2025
@albanD albanD requested a review from anijain2305 July 14, 2025 21:45
@albanD albanD added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Jul 14, 2025
@thenumberouscode
Contributor Author

thenumberouscode commented Jul 15, 2025

@anijain2305 Could you please take a look at my PR? Thank you

@thenumberouscode
Contributor Author

@anijain2305 Please review the PR at your convenience.

@ezyang
Contributor

ezyang commented Jul 16, 2025

@laithsakka please coordinate with #158385

@ezyang ezyang requested a review from laithsakka July 16, 2025 03:25
@thenumberouscode
Contributor Author

thenumberouscode commented Jul 16, 2025

#158385
thanks

@bobrenjc93
Contributor

Can you please include more detail in your summary? In particular, it'd be helpful if you could describe the change and why you didn't opt into any precision adjustments (similar to the comment in the other PR) and why that's sound.

Taking a step back, I do wonder if the "right" fix is to not run tensorify on non-placeholder item() calls. Did you try that?

@laithsakka
Contributor

I closed my PR. I will leave this to @thenumberouscode and @bobrenjc93, since he knows this part better.

@laithsakka laithsakka requested a review from bobrenjc93 July 16, 2025 18:06
):
    dtype = node.args[0].meta["val"].dtype
-   if dtype != torch.float64:
+   if dtype != torch.float64 and dtype != torch.float32:
Contributor


Shall we find a way to handle all float types, and add tests for when the input is float16, etc.?

Contributor Author


Yes, we really should cover all the float types. I'll look into it tomorrow.
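
For reference, the kind of coverage being discussed might look roughly like this (a hypothetical sketch, not the PR's actual test; the helper name and tolerance are made up):

import torch

torch._dynamo.config.capture_scalar_outputs = True  # let .item() stay in the traced graph

def item_times_two(x):
    # .item() on a float tensor lowers to aten._local_scalar_dense in the FX graph
    return x.sum().item() * 2.0

compiled = torch.compile(item_times_two)
for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
    x = torch.ones(3, dtype=dtype)
    # compiled result should match eager for every float dtype, not just float64
    assert abs(compiled(x) - item_times_two(x)) < 1e-3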

@bobrenjc93
Contributor

I think #158376 (comment) is probably quite related. Would you mind running that test in your PR as well?

@thenumberouscode
Contributor Author

Can you please include more detail in your summary? In particular, it'd be helpful if you could describe the change and why you didn't opt into any precision adjustments (similar to the comment in the other PR) and why that's sound.

Taking a step back, I do wonder if the "right" fix is to not run tensorify on non-placeholder item() calls. Did you try that?

Got it! I'll give it a try tomorrow.

@thenumberouscode
Contributor Author

I think #158376 (comment) is probably quite related. Would you mind running that test in your PR as well?

Yeah, I just saw the comment. I'll run the test tomorrow.

@thenumberouscode
Contributor Author

thenumberouscode commented Jul 21, 2025


@bobrenjc93 Wow, I tested it with the issue repro and it works! Just to make sure I understand what you mean, see the latest files changed:
https://github.com/pytorch/pytorch/pull/158216/files

@thenumberouscode
Contributor Author

thenumberouscode commented Jul 21, 2025

@bobrenjc93 I've read through the entire implementation of the tensorify_python_scalars function, and I have some questions. Could you help me with them? Thanks a lot!

1. The tensorify_python_scalars function finds nodes whose target is torch.ops.aten._local_scalar_dense.default and creates a proxy for them if the data type is torch.float64. Why is only this data type proxied? What about torch.float32 and the others? (Supporting torch.float32 is my bug-fix.)
2. Your bug-fix passed, and I'm curious why it works. How is an item() call in an ordinary scenario treated as a placeholder? Could you provide a demo for me to investigate?
3. Why do we only care about float variables? Why not integers?

# The general shape of this transformation is to look for Tensor operations
# that take a backed SymFloat as an argument, and then redo them as tensor
# compute (with ints and tensors as inputs). For example, add(Tensor, Scalar)
# can be translated into add(Tensor, Tensor). Because Dynamo has already
# arranged for floats to be Tensor inputs to the graph, for typical float
# compute you can entirely translate the Python float operations into Tensor
# operations with only Tensor inputs.

4. What is the relationship between a SymFloat and a Tensor in PyTorch? I'm a bit confused about the concepts, especially the meaning of "backed" in:
# If all symbols are backed symfloats, we can just specialize the whole node

5. I also noticed compute_dtype and the call to torch.ops.prims.convert_element_type.default. Why do we need them, and why does the map below cover only these three data types? (A short illustrative sketch follows the snippet below.)
https://github.com/pytorch/pytorch/blob/a527e816935957a164d74dd7c5069310b2857695/torch/fx/passes/_tensorify_python_scalars.py#L290C21-L296C26

_computation_dtype_map = {
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    torch.complex32: torch.complex64,
}

def get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
    return _computation_dtype_map.get(dtype, dtype)
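
For context, here is a rough, illustrative sketch of what the quoted pass comment and the compute_dtype helper describe (hypothetical names, not code from the pass): the float scalar stays as a 0-d tensor, and low-precision dtypes are upcast to a computation dtype before the arithmetic; the pass itself emits torch.ops.prims.convert_element_type.default for that cast.

import torch

_computation_dtype_map = {  # copied from the snippet above
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    torch.complex32: torch.complex64,
}

def sketch_tensorified_add(x: torch.Tensor, scale_t: torch.Tensor) -> torch.Tensor:
    # Instead of add(Tensor, Scalar) on a Python float extracted via item(),
    # keep the scalar as a 0-d tensor and emit add(Tensor, Tensor).
    compute_dtype = _computation_dtype_map.get(scale_t.dtype, scale_t.dtype)
    scale_t = scale_t.to(compute_dtype)  # stand-in for convert_element_type
    return torch.add(x, scale_t)

# e.g. a float16 scalar tensor is upcast to float32 before the add:
sketch_tensorified_add(torch.randn(4), torch.tensor(0.5, dtype=torch.float16))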

@thenumberouscode thenumberouscode force-pushed the bugfix_tensorfiy_python_scalars branch from 3fa7d36 to 33a0a97 Compare July 21, 2025 09:58
@thenumberouscode thenumberouscode requested review from a team and jeffdaily as code owners July 21, 2025 09:58
@thenumberouscode thenumberouscode force-pushed the bugfix_tensorfiy_python_scalars branch from 33a0a97 to 60e4526 Compare July 21, 2025 10:04
@thenumberouscode
Contributor Author

thenumberouscode commented Jul 21, 2025

#158376 (comment)

@bobrenjc93 I tried the test on my local dev machine, but ran into problems when executing the following command:

git clone https://github.com/huggingface/transformers.git && cd transformers && git checkout 6017f5e8 && pip install -e .[torch,testing]

It looks like a permission issue. When using sudo, another error occurs:

sudo pip install -e .[torch,testing]

I'm not going to try it again due to time constraints, but if you think the test is necessary, I can fix the errors and run it.

@ezyang
Contributor

ezyang commented Jul 21, 2025

  1. Why do we only care about float variables? Why not integers?

We support integer inputs directly because they are very commonly used for sizes, so it would be weird to put them into a tensor and then immediately extract them out. Floats, on the other hand, we don't support directly, so it is better to add that support via tensors.
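
An illustrative example of that distinction (not from this thread; the function below is hypothetical):

import torch

@torch.compile(dynamic=True)
def f(x: torch.Tensor, n: int, s: float) -> torch.Tensor:
    # n flows into shape compute (the repeat count), so keeping it as an
    # integer input is natural; s is pure value compute, so wrapping it in a
    # 0-d tensor avoids specializing on every new float value.
    return x.repeat(n) * s

f(torch.randn(3), 2, 0.5)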

isinstance(a, fx.Node)
and "val" in a.meta
and isinstance(zf := a.meta["val"], torch.SymFloat)
and a.op == "placeholder"
Contributor


This seems safe, but I think we probably also want to eventually tensorify item() calls as well (and yes, it seems like we would need to do some careful dtype sleuthing in this case.)

Contributor Author

@thenumberouscode thenumberouscode Jul 22, 2025


@ezyang Actually, it's not safe; this bug fix causes the unit test test_unspecialized_float_multiply_precision to fail (https://github.com/pytorch/pytorch/actions/runs/16414045402/job/46407765663?pr=158216).

After a lengthy investigation, I finally found the reason: if we add the constraint "and a.op == placeholder", a node whose op is call_function can no longer be wrapped by MetaProxy. As a result, wrap_symfloat is not called, which leads to three additional recompilations and the failure of test_unspecialized_float_multiply_precision.

# When float specialization is disabled, a Python float input is routed through wrap_symfloat:
elif not config.specialize_float and type(value) is float:
    return self.wrap_symfloat(value)

Summary:

We should not add the constraint. I believe the correct bug fix is still to support all float types in the branch below: all float-type scalars need to be wrappable in a Tensor in order to avoid recompilation issues. Please correct me if I'm wrong.

# Look for tensor.item() calls on placeholders
if (
    node is not None
    and node.op == "call_function"
    and node.target is torch.ops.aten._local_scalar_dense.default
):
    dtype = node.args[0].meta["val"].dtype
    if dtype != torch.float64:
        continue

Contributor Author



@ezyang I apologize for the incorrect conclusion above. Upon careful review, I found that the reason test_unspecialized_float_multiply_precision failed is due to the following code:

if isinstance(
    (val := node.meta.get("val")),
    (torch.SymFloat, torch.SymInt, torch.SymBool),
):
    if has_free_symbols(val.node.expr) and all(
        symbol_is_type(s, SymT.FLOAT) for s in val.node.expr.free_symbols
    ):
        # If all symbols are backed symfloats, we can just specialize the whole node
        # and get more precise guards. eg.
        #
        # zf = a.item()
        # zf2 = zf // 2
        # op(.. zf2 ..)
        #
        # It's better to guard on zf // 2 == 2.0 than zf == 5.0
        node.replace_all_uses_with(guard_scalar(val))
        graph.erase_node(node)

When the constraint and a.op == "placeholder" is added, the _local_scalar_dense node is not eliminated. Instead, it gets specialized to a concrete value by the logic above, which triggers a recompilation for each new value.
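
A minimal sketch of the recompilation pattern being described (an assumed shape based on the test name; not the actual body of test_unspecialized_float_multiply_precision):

import torch

def mul(x, s):
    # With float unspecialization, the Python float s becomes a tensor/SymFloat
    # input; if the pass instead specializes the scalar to a concrete value,
    # each new value of s triggers another recompilation.
    return x * s

compiled = torch.compile(mul)
for v in (0.25, 0.5, 0.75):
    compiled(torch.randn(4), v)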

@thenumberouscode
Contributor Author

thenumberouscode commented Jul 22, 2025

  1. Why do we only care about float variables? Why not integers?

We support integer inputs directly because they are very commonly used for sizes, so it would be weird to put them into a tensor and then immediately extract them out. Floats, on the other hand, we don't support directly, so it is better to add that support via tensors.

I found a possibly related PR: #125325
@ezyang I read this PR; is its main purpose to unspecialize floats into tensors to avoid recompilation issues?

@laithsakka
Contributor

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Aug 1, 2025

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@thenumberouscode
Contributor Author

@pytorchbot merge

@laithsakka Please trigger the workflow again, thanks.

…tensor in FX trace.

add up conversion

add unit test
@thenumberouscode thenumberouscode force-pushed the bugfix_tensorfiy_python_scalars branch from e2f86b9 to 4d89735 Compare August 4, 2025 09:20
@thenumberouscode
Contributor Author

@laithsakka Please rerun the workflow. Thanks.

@thenumberouscode
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 5, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@thenumberouscode
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2)

Details for Dev Infra team · Raised by workflow job

@thenumberouscode
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2)

Details for Dev Infra team · Raised by workflow job

@thenumberouscode
Contributor Author


@laithsakka Looks like we'll need to re-run this; a workflow job got canceled. By the way, could you give me access to run these workflows? That way I won't have to keep bugging you about it.

@thenumberouscode
Contributor Author

@laithsakka Please run the workflow again, thanks.

@laithsakka
Contributor

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable), trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025

Labels

ciflow/trunk (Trigger trunk jobs on your pull request) · fx · Merged · module: dynamo · open source · release notes: fx (release notes category) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.compile on .sum() and .item() calls errors from tensorify_python_scalars

8 participants