[fx][pass] Support converting a float32 tensor to a scalar in FX trace. #158216
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158216
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job, 1 Unrelated Failure as of commit 4d89735 with merge base f636736.
CANCELLED JOB: the following job was cancelled; please retry.
UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@anijain2305 Could you please take a look at my PR? Thank you.

@anijain2305 Please review the PR at your convenience.

@laithsakka please coordinate with #158385
Can you please include more detail in your summary? In particular, it'd be helpful if you could describe the change, why you didn't opt into any precision adjustments (similar to the comment in the other PR), and why that's sound. Taking a step back, I do wonder if the "right" fix is to not run tensorify on non-placeholder item() calls. Did you try that?

I closed my PR; I will leave this for @thenumberouscode and @bobrenjc93 since he knows this part better.
```diff
 ):
     dtype = node.args[0].meta["val"].dtype
-    if dtype != torch.float64:
+    if dtype != torch.float64 and dtype != torch.float32:
```
Shall we find a way to handle all float types, and add tests for when the input is float16, etc.?
Yes, we really should cover all the float types. I'll look into it tomorrow.
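For illustration, a minimal sketch (assuming the node/metadata layout quoted in the diff above) of a gate that accepts every floating-point dtype; `is_float_item_call` is a hypothetical helper name, not the upstream code:

```python
import torch

# Floating-point dtypes the gate would accept; float16/bfloat16 are the extra
# cases the review asks to add tests for.
_FLOAT_DTYPES = (torch.float64, torch.float32, torch.float16, torch.bfloat16)

def is_float_item_call(node) -> bool:
    """True for tensor.item() calls (aten._local_scalar_dense) whose input
    tensor has a floating-point dtype."""
    if node.op != "call_function":
        return False
    if node.target is not torch.ops.aten._local_scalar_dense.default:
        return False
    dtype = node.args[0].meta["val"].dtype
    return dtype in _FLOAT_DTYPES
```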
I think #158376 (comment) is probably quite related. Would you mind running that test in your PR as well?

Got it! I'll give it a try tomorrow.

Yeah, I just saw the comment. I'll run the test tomorrow.

@bobrenjc93 Wow, I tested it with the issue repro and it works! Just to make sure I understand what you mean, see the newest files changed:
@bobrenjc93 I've read through the entire implementation of the tensorify_python_scalars function, and I have some questions. Could you help me with them? Thanks a lot!

1. The tensorify_python_scalars function finds nodes whose target is torch.ops.aten._local_scalar_dense.default and creates a proxy for them if the data type is torch.float64. Why is a proxy needed only for this data type? What about torch.float32 and the other float types? (Supporting torch.float32 is my bug-fixing solution; see the sketch after this comment.)
   pytorch/torch/fx/passes/_tensorify_python_scalars.py, lines 40 to 46 in a527e81
4. What's the relationship between a SymFloat and a Tensor in PyTorch? I'm a bit confused about the concepts, especially the notion of a "backed" SymFloat.
5. I also noticed compute_dtype and the call to torch.ops.prims.convert_element_type.default. Why do we need compute_dtype and that call, and why can only these three data types be computed? https://github.com/pytorch/pytorch/blob/a527e816935957a164d74dd7c5069310b2857695/torch/fx/passes/_tensorify_python_scalars.py#L290C21-L296C26

```python
_computation_dtype_map = {
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    torch.complex32: torch.complex64,
}

def get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
    return _computation_dtype_map.get(dtype, dtype)
```
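For reference, a small hedged sketch connecting questions 1 and 5 (assumes a recent PyTorch build; the map is copied from the snippet above): tensor.item() and the aten._local_scalar_dense op return the same Python scalar, and get_computation_dtype only promotes the reduced-precision dtypes in the map, leaving float32/float64 untouched.

```python
import torch

t = torch.tensor(1.5, dtype=torch.float32)
# .item() and the op the pass pattern-matches on yield the same Python scalar.
assert t.item() == torch.ops.aten._local_scalar_dense.default(t)

_computation_dtype_map = {
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    torch.complex32: torch.complex64,
}

def get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
    return _computation_dtype_map.get(dtype, dtype)

# Reduced-precision dtypes are promoted so intermediate scalar math runs in a
# wider dtype; float32 (and float64) already compute in their own dtype.
assert get_computation_dtype(torch.float16) == torch.float32
assert get_computation_dtype(torch.float32) == torch.float32
```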
Force-pushed from 3fa7d36 to 33a0a97
Force-pushed from 33a0a97 to 60e4526
@bobrenjc93 I tried the test on my local dev machine, but ran into problems when executing the following command: git clone https://github.com/huggingface/transformers.git && cd transformers && git checkout 6017f5e8 && pip install -e .[torch,testing]

It looks like a permission issue. When using sudo, another error occurs: sudo pip install -e .[torch,testing]

I'm not going to try it again due to time constraints, but if you think the test is necessary, I can fix the errors and run it.
We support integer inputs directly because they are very commonly used for sizes, so it would be weird to put them into a tensor and then immediately extract them out. Floats, on the other hand, we don't support directly, and it is better to provide that support via a tensor.
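A hedged illustration of that asymmetry (function names are illustrative, not the pass's internals): an integer scalar such as a size is consumed directly, while a float scalar is wrapped into a 0-d tensor so the graph never carries a raw Python float.

```python
import torch

def use_size_directly(n: int) -> torch.Tensor:
    # Sizes stay plain ints; wrapping them in a tensor only to .item() them
    # straight back out would be pointless round-tripping.
    return torch.empty(n)

def tensorify_float(f: float, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Floats become 0-d tensors, which keeps the dtype explicit and lets the
    # compiled graph treat the value like any other tensor input.
    return torch.scalar_tensor(f, dtype=dtype)

x = use_size_directly(4)   # shape (4,)
s = tensorify_float(2.5)   # tensor(2.5000)
y = x * s                  # the float participates as a tensor
```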
```python
isinstance(a, fx.Node)
and "val" in a.meta
and isinstance(zf := a.meta["val"], torch.SymFloat)
and a.op == "placeholder"
```
This seems safe, but I think we probably also want to eventually tensorify item() calls as well (and yes, it seems like we would need to do some careful dtype sleuthing in this case.)
@ezyang Actually, it's not safe; this bug fix could cause the unit test test_unspecialized_float_multiply_precision to fail (https://github.com/pytorch/pytorch/actions/runs/16414045402/job/46407765663?pr=158216).

Error stack:

After a lengthy investigation, I finally found the reason: if we add the constraint `and a.op == "placeholder"`, a node whose op is call_function cannot be wrapped by MetaProxy. As a result, wrap_symfloat will not be called, so there are another three recompilations, leading to the failure of test_unspecialized_float_multiply_precision.
pytorch/torch/_dynamo/variables/builder.py, lines 1945 to 1946 in 9b4d938:

```python
elif not config.specialize_float and type(value) is float:
    return self.wrap_symfloat(value)
```
Summary: we should not add the constraint. I believe the correct bug fix is still to support all float types in the following branch; we need all float-type scalars to be wrapped in a Tensor in order to avoid recompilation issues. Please correct me if I'm wrong.
pytorch/torch/fx/passes/_tensorify_python_scalars.py, lines 199 to 208 in 07df6ba:

```python
# Look for tensor.item() calls on placeholders
if (
    node is not None
    and node.op == "call_function"
    and node.target is torch.ops.aten._local_scalar_dense.default
):
    dtype = node.args[0].meta["val"].dtype
    if dtype != torch.float64:
        continue
```
@ezyang I apologize for the incorrect conclusion above. Upon careful review, I found that the reason test_unspecialized_float_multiply_precision failed is the following code:
pytorch/torch/fx/passes/_tensorify_python_scalars.py, lines 327 to 345 in d3ce450:

```python
if isinstance(
    (val := node.meta.get("val")),
    (torch.SymFloat, torch.SymInt, torch.SymBool),
):
    if has_free_symbols(val.node.expr) and all(
        symbol_is_type(s, SymT.FLOAT) for s in val.node.expr.free_symbols
    ):
        # If all symbols are backed symfloats, we can just specialize the whole node
        # and get more precise guards. eg.
        #
        # zf = a.item()
        # zf2 = zf // 2
        # op(.. zf2 ..)
        #
        # It's better to guard on zf // 2 == 2.0 than zf == 5.0
        node.replace_all_uses_with(guard_scalar(val))
        graph.erase_node(node)
```
When the constraint `and a.op == "placeholder"` is added, the _local_scalar_dense node is not eliminated. Instead, it gets resolved to a specific value by the logic above, which triggers a recompilation for each specific value (see the sketch below).
Found a possibly related PR: #125325
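To make the recompilation point concrete, a hedged sketch (exact compile counts depend on PyTorch version and config such as specialize_float): if a Python float input is specialized on its concrete value, each new value forces a fresh compile, which is what wrapping floats in tensors is meant to avoid.

```python
import torch

torch._dynamo.reset()

@torch.compile(backend="eager")
def mul(x, s):
    return x * s

x = torch.ones(4)
for s in (0.5, 1.5, 2.5):
    mul(x, s)
# With float inputs specialized on their value, each distinct s above would
# trigger its own recompilation; with unspecialized (tensorified) floats a
# single graph is reused. Compile activity can be inspected via
# torch._dynamo.utils.counters if desired.
```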
@pytorchbot merge

Pull workflow has not been scheduled for the PR yet. It could be because the author doesn't have permissions to run those or skip-checks keywords were added to the PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip-ci decorators before the next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@laithsakka Please trigger the workflow again, thanks.
…tensor in FX trace. add up conversion add unit test
Force-pushed from e2f86b9 to 4d89735
@laithsakka Please rerun the workflow. Thanks.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed; the first few of them are: trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed; the first few of them are: trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2). Details for Dev Infra team: raised by workflow job.
@laithsakka Looks like we'll need to re-run this; a workflow job got canceled. By the way, could you give me access to run these workflows? That way I won't have to keep bugging you about it.

@laithsakka Please run the workflow again, thanks.

@pytorchbot merge -i

Merge started. Your change will be merged while ignoring the following 2 checks: pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable), trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…e. (pytorch#158216)
Fixes pytorch#158083
Pull Request resolved: pytorch#158216
Approved by: https://github.com/laithsakka


Fixes #158083
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela
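As a closing illustration, a hedged sketch of the pattern this PR targets (inferred from the title, not the literal repro from #158083): calling .item() on a float32 tensor inside a compiled region, which lowers to aten._local_scalar_dense and is what the tensorify pass now also handles for float32.

```python
import torch

# Assumption: capture_scalar_outputs lets Dynamo trace through .item().
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(backend="eager")
def f(x):
    s = x.sum().item()  # float32 tensor -> Python scalar (aten._local_scalar_dense)
    return x * s

print(f(torch.ones(3, dtype=torch.float32)))
```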