Detect torch function in lists as well by ezyang · Pull Request #160256 · pytorch/pytorch · GitHub

Conversation

@ezyang (Contributor) commented Aug 9, 2025

Stack from ghstack (oldest at bottom):

We basically follow the same pattern we use for tensor arguments. The major downside is that we now have to traverse the entirety of the int list (and similar list arguments) where previously we didn't have to. A benchmark suggests a ~2% regression for the affected calls.
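A minimal sketch of the shape of the check, assuming a hypothetical wrapper named list_has_torch_function around the existing check_has_torch_function helper (illustrative only, not the PR's actual code):

#include <Python.h>

// Existing helper; forward-declared here so the sketch is self-contained.
bool check_has_torch_function(PyObject* obj, bool ignore_mode);

// Walk an int-list-like argument and probe each element for
// __torch_function__, mirroring the tensor-argument pattern. This O(n)
// traversal is the source of the ~2% regression mentioned above.
static bool list_has_torch_function(PyObject* obj) {
  const bool is_tuple = PyTuple_Check(obj);
  const Py_ssize_t size =
      is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    PyObject* item =
        is_tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
    if (check_has_torch_function(item, /*ignore_mode=*/true)) {
      return true;  // an element overrides __torch_function__
    }
  }
  return false;
}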

Signed-off-by: Edward Yang <ezyang@meta.com>

cc @gchanan

[ghstack-poisoned]
@pytorch-bot bot commented Aug 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160256

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 0e7dbe4 with merge base 8171d60:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Aug 9, 2025
This was done exclusively with claude code and I haven't reviewed it yet

Signed-off-by: Edward Yang <ezyang@meta.com>
ghstack-source-id: 0a57888
Pull-Request: #160256
[ghstack-poisoned]
@ezyang ezyang mentioned this pull request Aug 10, 2025
@ezyang ezyang marked this pull request as ready for review August 10, 2025 04:28
@ezyang (Contributor, Author) commented Aug 10, 2025

I have reviewed it; some of the code is bad, but it "works". I need to improve some of its performance characteristics.

return false;
bool has_torch_func = false;

for (long idx = 0; idx < size; idx++) {
ezyang (Contributor, Author):

The iteration here is the perf problem. Ideally we would delay checking the contents until we are parsing, but that may require a more involved change upstream, since we typically assume that by the time we parse, a torch function can no longer occur.

Collaborator:

You should use c10::irange here, right?
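For reference, this is what the suggested change would look like (c10::irange is PyTorch's range helper from c10/util/irange.h; the loop body is elided):

#include <c10/util/irange.h>

// Equivalent to: for (long idx = 0; idx < size; idx++)
for (const auto idx : c10::irange(size)) {
  // ... same loop body as before ...
}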

ezyang (Contributor, Author):

That's just the color of the shed; the real problem is I'm adding O(n) extra CPython probes for int list arguments. I need to check to see if the overhead is perceptible.

PyObject* obj,
int broadcast_size,
int64_t* failed_idx = nullptr) {
int64_t* failed_idx = nullptr,
Collaborator:

Any reason we want pointers here instead of an optional reference? nullptr seems more error-prone, especially when wrapping an integer type, and std::optional lets us statically guard against invalid accesses.

ezyang (Contributor, Author):

Pre-existing condition.
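For context, a hedged sketch of the std::optional alternative under discussion; the function and argument names here are illustrative, not the PR's code:

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Instead of int64_t* failed_idx = nullptr, take a reference to an optional:
// there is no null pointer to forget to check, and reading an unset value via
// value() throws std::bad_optional_access rather than dereferencing nullptr.
static bool all_nonnegative(const std::vector<int64_t>& vals,
                            std::optional<int64_t>& failed_idx) {
  for (std::size_t i = 0; i < vals.size(); i++) {
    if (vals[i] < 0) {
      failed_idx = static_cast<int64_t>(i);  // record the offending index
      return false;
    }
  }
  return true;
}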

static bool is_scalar_list(
PyObject* obj,
std::vector<PyObject*>* overloaded_args = nullptr) {
auto tuple = six::isTuple(obj);
Collaborator:

Six? Uh, we missed this in the upgrade, didn't we... just use the pybind11 handle APIs.

ezyang (Contributor, Author):

Better to do this separately
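For reference, the pybind11 replacement would presumably look something like this (py::handle wraps the raw PyObject* without touching reference counts):

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Instead of six::isTuple(obj):
bool tuple = py::isinstance<py::tuple>(py::handle(obj));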

@ezyang (Contributor, Author) commented Aug 10, 2025

Some not-very-scientific benchmarking suggests this adds something like 40ns of overhead per call, where the calls take about 2000ns end to end (so roughly a 2% regression).

ezyang added 2 commits August 10, 2025 16:48
[ghstack-poisoned]
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Aug 11, 2025
This was done exclusively with claude code and I haven't reviewed it yet

Signed-off-by: Edward Yang <ezyang@meta.com>
ghstack-source-id: 2b5d285
Pull-Request: #160256
@albanD (Collaborator) left a comment:

Perf hit sounds fair for the benefit!

# Fallback
return torch.tensor(42.0)

def __index__(self):
Collaborator:

I guess __index__ implies __int__?

ezyang (Contributor, Author):

this is an LLM test, I can delete it lol

return torch.ones_like(args[0])
return torch.tensor(42.0)

def __float__(self):
Collaborator:

What happens if this doesn't implement __float__? Same question for the __int__ types, both when they're first in the list and when they're not.

I would add error cases for these.


for (long idx = 0; idx < size; idx++) {
PyObject* iobj =
tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
Collaborator:

Not sure how to solve this one, but Sam is going to hunt you down: https://py-free-threading.github.io/porting-extensions/#unsafe-apis-returning-borrowed-references

The tuple side is fine, but on the list side you should use PyList_GetItemRef. You then need a conditional decref and have to handle early exit properly.
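A hedged sketch of the pattern being described (PyList_GetItemRef is CPython >= 3.13 and returns a strong reference; the predicate is a placeholder for the actual torch-function check):

#include <Python.h>

static bool sequence_any_matches(PyObject* obj, bool (*pred)(PyObject*)) {
  const bool is_tuple = PyTuple_Check(obj);
  const Py_ssize_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_Size(obj);
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    PyObject* item = is_tuple
        ? PyTuple_GET_ITEM(obj, idx)    // borrowed ref: tuples are immutable
        : PyList_GetItemRef(obj, idx);  // strong ref: must be released
    if (item == nullptr) {
      PyErr_Clear();  // the list shrank under us; stop scanning
      break;
    }
    const bool match = pred(item);
    if (!is_tuple) {
      Py_DECREF(item);  // the conditional decref mentioned above
    }
    if (match) {
      return true;  // early exit is safe: the reference was already released
    }
  }
  return false;
}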

ezyang (Contributor, Author):

Will you get mad if I just use PySequence LOL

ezyang (Contributor, Author):

BTW we got a lot of these. Maybe I can ask Codex to fix them:

(pytorch-tmp2) ezyang-mac:pytorch-tmp2 ezyang$ git grep PyList_GET_ITEM
functorch/csrc/dim/dim.cpp:  PyObject** begin = &PyList_GET_ITEM(tv.ptr(), 0);
functorch/csrc/dim/minpybind.h:        return PyList_GET_ITEM(ptr(), i);
torch/_inductor/codecache.py:                void* elem = PyCapsule_GetPointer(PyList_GET_ITEM(pyvec, i), NULL);
torch/_inductor/codegen/cpp_wrapper_cpu.py:                lines += f"{output_arg} = reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));\n"  # noqa: B950
torch/csrc/autograd/init.cpp:        if (!THPVariable_Check(PyList_GET_ITEM(o, i))) {
torch/csrc/autograd/init.cpp:              PyList_GET_ITEM(o, i), visit_tensor)) {
torch/csrc/autograd/python_variable.h:    PyObject* item = PyList_GET_ITEM(pyresult, i);
torch/csrc/dynamo/python_compiled_autograd.cpp:            py::cast<c10::SymInt>(PyList_GET_ITEM(pyresult, idx++)));
torch/csrc/dynamo/python_compiled_autograd.cpp:          py::cast<c10::SymInt>(PyList_GET_ITEM(fake_ivalue_args, i)));
torch/csrc/dynamo/python_compiled_autograd.cpp:          py::cast<c10::SymFloat>(PyList_GET_ITEM(fake_ivalue_args, i)));
torch/csrc/fx/node.cpp:      PyObject* elem = PyList_GET_ITEM(a, i); // borrowed ref
torch/csrc/jit/passes/onnx/shape_type_inference.cpp:        auto list_elem = PyList_GET_ITEM(output_obj, 0);
torch/csrc/jit/passes/onnx/shape_type_inference.cpp:          list_elem = PyList_GET_ITEM(output_obj, i);
torch/csrc/jit/passes/onnx/shape_type_inference.cpp:            PyList_GET_ITEM(output_obj, i),
torch/csrc/jit/passes/onnx/shape_type_inference.cpp:          PyList_GET_ITEM(unrolled_dict.ptr(), i),
torch/csrc/python_dimname.cpp:      tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
torch/csrc/utils.cpp:          tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i);
torch/csrc/utils.cpp:          tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
torch/csrc/utils.cpp:          tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
torch/csrc/utils/python_arg_parser.cpp:        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
torch/csrc/utils/python_arg_parser.cpp:        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
torch/csrc/utils/python_arg_parser.cpp:        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
torch/csrc/utils/python_arg_parser.cpp:          is_tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
torch/csrc/utils/python_arg_parser.h:                          : PyList_GET_ITEM(arg.get(), idx);
torch/csrc/utils/python_arg_parser.h:                          : PyList_GET_ITEM(arg.get(), idx);
torch/csrc/utils/python_arg_parser.h:                          : PyList_GET_ITEM(arg.get(), idx);
torch/csrc/utils/python_arg_parser.h:                          : PyList_GET_ITEM(arg.get(), idx);
torch/csrc/utils/python_arg_parser.h:        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
torch/csrc/utils/python_arg_parser.h:        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
torch/csrc/utils/python_arg_parser.h:        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
torch/csrc/utils/python_arg_parser.h:        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);

@albanD (Collaborator) commented Aug 11, 2025:

> Will you get mad if I just use PySequence LOL

As long as you only call it for list and tuple, sounds ok to me :)

> BTW we got a lot of these.

Tuples are ok as they're immutable :)
The list ones I thought I had gone through, but I guess I missed these :(

ezyang (Contributor, Author):

I chatted with @colesbury about this and he said there are basically three ways we can do it:

  1. Don't worry about it. (Pretty good option imo)
  2. Use PyList_GetItemRef instead of PyList_GET_ITEM and handle the refcounting
  3. Lock the list with Py_BEGIN_CRITICAL_SECTION(pyvec); at the beginning

(1) is the easiest.
(3) is probably the most correct, because you get a consistent view of the list, including its size.

For the code here, where we're already penny-pinching nanoseconds, it's probably better to do (1).
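For illustration, a sketch of option (3) using CPython 3.13's critical-section API, with the existing check_has_torch_function helper forward-declared:

#include <Python.h>

bool check_has_torch_function(PyObject* obj, bool ignore_mode);

static bool list_any_torch_function_locked(PyObject* list) {
  bool found = false;
  Py_BEGIN_CRITICAL_SECTION(list);
  // The lock gives a consistent view of the list, including its size.
  const Py_ssize_t size = PyList_GET_SIZE(list);
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    // Borrowed references are safe while the critical section is held.
    PyObject* item = PyList_GET_ITEM(list, idx);
    if (check_has_torch_function(item, /*ignore_mode=*/true)) {
      found = true;
      break;  // break (not return) so Py_END_CRITICAL_SECTION still runs
    }
  }
  Py_END_CRITICAL_SECTION();
  return found;
}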

Collaborator:

Ok!


for (Py_ssize_t idx = 0; idx < size; idx++) {
PyObject* item_ptr =
is_tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
Collaborator:

Same thread safety issue

[ghstack-poisoned]
@ezyang ezyang requested a review from soulitzer as a code owner August 12, 2025 00:27
[ghstack-poisoned]
auto* obj = PyTuple_GetItem(index_tup.ptr(), i);
is_tensor_and_append_overloaded(obj, &overridable_args);
auto r = is_tensor_and_append_overloaded(obj, &overridable_args);
if (!r && PySequence_Check(obj)) {
Collaborator:

Are we guaranteed we can't get a Tensor in here? :)

ezyang (Contributor, Author):

Seems pretty guaranteed to me?

bool is_tensor_and_append_overloaded(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args) {
  if (THPVariable_CheckExact(obj)) {
    // torch.Tensor instances (not subclasses, except for Parameter)
    return true;
  }

  if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
    // tensor subclasses and unrelated objects with __torch_function__
    append_overloaded_tensor(overloaded_args, obj);
    return true;
  } else if (THPVariable_Check(obj)) {
    // tensor subclasses without __torch_function__
    return true;
  }

  return false;
}

[ghstack-poisoned]
@ezyang ezyang added release notes: python_frontend python frontend release notes category topic: new features topic category labels Aug 16, 2025
[ghstack-poisoned]
@ezyang (Contributor, Author) commented Aug 16, 2025

Review comments addressed. (I think I put this comment on the wrong PR lol; I did actually resolve most of them, just missed one.)

@ezyang ezyang added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 16, 2025
@ezyang (Contributor, Author) commented Aug 31, 2025

This is waiting for final approval!

ezyang added 2 commits August 31, 2025 15:13
[ghstack-poisoned]
[ghstack-poisoned]
@albanD (Collaborator) left a comment:

Ok!
Let's mark this as BC-breaking so we can nicely track it in the release notes.

@albanD albanD added module: bc-breaking Related to a BC-breaking change topic: bc breaking topic category labels Sep 2, 2025
@ezyang (Contributor, Author) commented Sep 2, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Comment on lines +972 to +973
const bool is_tuple = PyTuple_Check(obj);
const auto size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
Contributor:

Perf nitpick: if we're doing this optimization here, we should have hoisted it up to lines 965-968 so we don't hit PySequence_Size. We should also call PyTuple_Check only once in total.
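A sketch of the suggested hoisting (illustrative; the actual fix landed as #161998):

// Do the tuple check once, then derive the size from it, instead of
// calling PySequence_Size earlier and PyTuple_Check a second time here.
const bool is_tuple = PyTuple_Check(obj);
const auto size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);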

Contributor:

sent #161998

swolchok added a commit that referenced this pull request Sep 2, 2025
This function has come up in DTensor perf work, and I had a nitpick on #160256 so here it is. I have neither compiled nor measured this, but am reasonably confident it's better nonetheless.

[ghstack-poisoned]
swolchok added a commit that referenced this pull request Sep 2, 2025
This function has come up in DTensor perf work, and I had a nitpick on #160256 so here it is. I have neither compiled nor measured this, but am reasonably confident it's better nonetheless.

ghstack-source-id: 5d9d3d5
Pull Request resolved: #161998
pytorchmergebot pushed a commit that referenced this pull request Sep 3, 2025
This function has come up in DTensor perf work, and I had a nitpick on #160256 so here it is. I have neither compiled nor measured this, but am reasonably confident it's better nonetheless.
Pull Request resolved: #161998
Approved by: https://github.com/ezyang
ezyang added a commit to ezyang/pytorch that referenced this pull request Sep 10, 2025
This was done exclusively with claude code and I haven't reviewed it yet

Signed-off-by: Edward Yang <ezyang@meta.com>
ghstack-source-id: f48741d
Pull-Request: pytorch#160256
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
We basically follow the same pattern we do for tensor arguments. The major downside is we now have to traverse the entirety of the int list / etc where previously we didn't have. Benchmark suggests 2% regression for relevant things.

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: pytorch#160256
Approved by: https://github.com/albanD
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…61998)

This function has come up in DTensor perf work, and I had a nitpick on pytorch#160256 so here it is. I have neither compiled nor measured this, but am reasonably confident it's better nonetheless.
Pull Request resolved: pytorch#161998
Approved by: https://github.com/ezyang
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
We basically follow the same pattern we do for tensor arguments. The major downside is we now have to traverse the entirety of the int list / etc where previously we didn't have. Benchmark suggests 2% regression for relevant things.

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: pytorch#160256
Approved by: https://github.com/albanD
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
…61998)

This function has come up in DTensor perf work, and I had a nitpick on pytorch#160256 so here it is. I have neither compiled nor measured this, but am reasonably confident it's better nonetheless.
Pull Request resolved: pytorch#161998
Approved by: https://github.com/ezyang
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
…61998)

This function has come up in DTensor perf work, and I had a nitpick on pytorch#160256 so here it is. I have neither compiled nor measured this, but am reasonably confident it's better nonetheless.
Pull Request resolved: pytorch#161998
Approved by: https://github.com/ezyang
@github-actions github-actions bot deleted the gh/ezyang/3128/head branch October 3, 2025 02:09

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: bc-breaking (Related to a BC-breaking change)
release notes: python_frontend (python frontend release notes category)
topic: bc breaking (topic category)
topic: new features (topic category)
