Detect torch function in lists as well #160256
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160256
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures as of commit 0e7dbe4 with merge base 8171d60.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I have reviewed it and some of the code is bad but it "works". Need to improve some performance characteristics for it.
  return false;
  bool has_torch_func = false;

  for (long idx = 0; idx < size; idx++) {
The iteration here is the perf problem. Ideally we delay checking the insides until we are parsing. But this may result in a more involved change upstream as we typically assume by the time we parse TF cannot occur.
You should use c10::irange here, right?
That's just the color of the shed; the real problem is I'm adding O(n) extra CPython probes for int list arguments. I need to check to see if the overhead is perceptible.
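For reference, a minimal sketch of what the c10::irange version of this scan might look like, reusing the obj / size / tuple names visible in the hunks of this diff and the check_has_torch_function helper used elsewhere in this file (not the PR's exact code):

#include <c10/util/irange.h>

// Walk the list/tuple once and record whether any element carries
// __torch_function__; variable names are assumed from the diff context.
bool has_torch_func = false;
for (const auto idx : c10::irange(size)) {
  PyObject* iobj =
      tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
  if (check_has_torch_function(iobj)) {
    has_torch_func = true;
    break;  // one overloaded element is enough to trigger dispatch
  }
}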
    PyObject* obj,
    int broadcast_size,
-   int64_t* failed_idx = nullptr) {
+   int64_t* failed_idx = nullptr,
Any reasons we want ptrs here instead of optional reference? Nullptr seems more error prone, especially when wrapping an integer type. We can statically guard against invalid std::optional accesses.
Pre-existing condition.
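To illustrate the suggestion with a hypothetical helper (not PyTorch code), the optional-based shape would look roughly like this:

#include <cstdint>
#include <optional>

// Hypothetical example: report the offending index through std::optional
// instead of a nullable out-pointer, so callers can't forget the nullptr check.
static bool all_nonnegative(
    const int64_t* data,
    int64_t size,
    std::optional<int64_t>& failed_idx) {
  failed_idx = std::nullopt;
  for (int64_t i = 0; i < size; ++i) {
    if (data[i] < 0) {
      failed_idx = i;  // set only on failure
      return false;
    }
  }
  return true;
}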
static bool is_scalar_list(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args = nullptr) {
  auto tuple = six::isTuple(obj);
Six? Uh we missed this in the upgrade didn't we... just use pybind11 handle APIs
Better to do this separately
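For reference, the pybind11-handle version would be roughly this (a sketch, assuming six::isTuple is just a PyTuple_Check wrapper left over from the Python 2 days):

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Sketch: replace the six:: shim with a plain pybind11/CPython tuple check.
static bool is_tuple_like(PyObject* obj) {
  return py::isinstance<py::tuple>(py::handle(obj));
  // equivalently: return PyTuple_Check(obj) != 0;
}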
Some not very scientific benchmarking suggests this is something like 40ns overhead per call, where the calls end to end take 2000ns (so like 2% regression or something).
Perf hit sounds fair for the benefit!
test/test_overrides.py
Outdated
        # Fallback
        return torch.tensor(42.0)

    def __index__(self):
I guess __index__ implies __int__ ?
this is an LLM test, I can delete it lol
test/test_overrides.py
Outdated
            return torch.ones_like(args[0])
        return torch.tensor(42.0)

    def __float__(self):
What happens if this doesn't implement __float__ ?
Same question for the __int__ types?
Both when they're first and not first.
I would add error cases for these.
  for (long idx = 0; idx < size; idx++) {
    PyObject* iobj =
        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
Not sure how to solve this one, but Sam is going to hunt you down: https://py-free-threading.github.io/porting-extensions/#unsafe-apis-returning-borrowed-references
The tuple side is fine, but on the list side you should use PyList_GetItemRef.
But then you need conditional decref and handle early exit properly.
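A sketch of what the PyList_GetItemRef pattern could look like here (assumes CPython 3.13+ or the pythoncapi-compat shim, plus the obj / size / tuple / has_torch_func names from the diff):

for (Py_ssize_t idx = 0; idx < size; idx++) {
  PyObject* iobj = tuple
      ? PyTuple_GET_ITEM(obj, idx)    // borrowed ref; fine, tuples are immutable
      : PyList_GetItemRef(obj, idx);  // new ref; safe under free-threading
  if (!tuple && iobj == nullptr) {
    PyErr_Clear();
    break;  // list shrank concurrently; bail out conservatively
  }
  const bool found = check_has_torch_function(iobj);
  if (!tuple) {
    Py_DECREF(iobj);  // conditional decref: only the list branch owns a reference
  }
  if (found) {
    has_torch_func = true;
    break;
  }
}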
Will you get mad if I just use PySequence LOL
BTW we got a lot of these. Maybe I can ask Codex to fix them:
(pytorch-tmp2) ezyang-mac:pytorch-tmp2 ezyang$ git grep PyList_GET_ITEM
functorch/csrc/dim/dim.cpp: PyObject** begin = &PyList_GET_ITEM(tv.ptr(), 0);
functorch/csrc/dim/minpybind.h: return PyList_GET_ITEM(ptr(), i);
torch/_inductor/codecache.py: void* elem = PyCapsule_GetPointer(PyList_GET_ITEM(pyvec, i), NULL);
torch/_inductor/codegen/cpp_wrapper_cpu.py: lines += f"{output_arg} = reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));\n" # noqa: B950
torch/csrc/autograd/init.cpp: if (!THPVariable_Check(PyList_GET_ITEM(o, i))) {
torch/csrc/autograd/init.cpp: PyList_GET_ITEM(o, i), visit_tensor)) {
torch/csrc/autograd/python_variable.h: PyObject* item = PyList_GET_ITEM(pyresult, i);
torch/csrc/dynamo/python_compiled_autograd.cpp: py::cast<c10::SymInt>(PyList_GET_ITEM(pyresult, idx++)));
torch/csrc/dynamo/python_compiled_autograd.cpp: py::cast<c10::SymInt>(PyList_GET_ITEM(fake_ivalue_args, i)));
torch/csrc/dynamo/python_compiled_autograd.cpp: py::cast<c10::SymFloat>(PyList_GET_ITEM(fake_ivalue_args, i)));
torch/csrc/fx/node.cpp: PyObject* elem = PyList_GET_ITEM(a, i); // borrowed ref
torch/csrc/jit/passes/onnx/shape_type_inference.cpp: auto list_elem = PyList_GET_ITEM(output_obj, 0);
torch/csrc/jit/passes/onnx/shape_type_inference.cpp: list_elem = PyList_GET_ITEM(output_obj, i);
torch/csrc/jit/passes/onnx/shape_type_inference.cpp: PyList_GET_ITEM(output_obj, i),
torch/csrc/jit/passes/onnx/shape_type_inference.cpp: PyList_GET_ITEM(unrolled_dict.ptr(), i),
torch/csrc/python_dimname.cpp: tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
torch/csrc/utils.cpp: tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i);
torch/csrc/utils.cpp: tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
torch/csrc/utils.cpp: tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
torch/csrc/utils/python_arg_parser.cpp: tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
torch/csrc/utils/python_arg_parser.cpp: tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
torch/csrc/utils/python_arg_parser.cpp: tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
torch/csrc/utils/python_arg_parser.cpp: is_tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
torch/csrc/utils/python_arg_parser.h: : PyList_GET_ITEM(arg.get(), idx);
torch/csrc/utils/python_arg_parser.h: : PyList_GET_ITEM(arg.get(), idx);
torch/csrc/utils/python_arg_parser.h: : PyList_GET_ITEM(arg.get(), idx);
torch/csrc/utils/python_arg_parser.h: : PyList_GET_ITEM(arg.get(), idx);
torch/csrc/utils/python_arg_parser.h: tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
torch/csrc/utils/python_arg_parser.h: tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
torch/csrc/utils/python_arg_parser.h: tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
torch/csrc/utils/python_arg_parser.h: tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
> Will you get mad if I just use PySequence LOL

As long as you only call it for list and tuple, sounds ok to me :)

> BTW we got a lot of these.

Tuple is ok as they're immutable :)
The list ones I thought I went through but I guess I missed them :(
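For completeness, the PySequence variant mentioned above would look something like this (sketch; PySequence_GetItem returns a new reference for both lists and tuples, at the cost of an extra C-API call per element):

PyObject* iobj = PySequence_GetItem(obj, idx);  // new ref for list and tuple alike
if (iobj == nullptr) {
  PyErr_Clear();
  break;  // concurrent mutation or index error; give up conservatively
}
const bool found = check_has_torch_function(iobj);
Py_DECREF(iobj);  // unconditional decref, no tuple/list special-casing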
I chatted with @colesbury about this and he said there's basically three ways we can do it:
- Don't worry about it. (Pretty good option imo)
- Use PyList_GetItemRef instead of PyList_GET_ITEM and handle the refcounting
- Lock the list with Py_BEGIN_CRITICAL_SECTION(pyvec); at the beginning
(1) is the easiest
(3) is probably the most correct because you get a consistent view of the list including the size
For the code here, where we're already penny pinching nanoseconds, it's probably better to do (1)
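Option (3) above would look roughly like this (sketch; Py_BEGIN_CRITICAL_SECTION is a CPython 3.13+ API and essentially compiles away on non-free-threaded builds):

Py_BEGIN_CRITICAL_SECTION(obj);
// While the list is locked, the size and the borrowed item pointers stay
// consistent, so the existing PyList_GET_ITEM loop can remain as-is.
const Py_ssize_t locked_size = PyList_GET_SIZE(obj);
for (Py_ssize_t idx = 0; idx < locked_size; idx++) {
  PyObject* iobj = PyList_GET_ITEM(obj, idx);  // borrowed ref, safe under the lock
  if (check_has_torch_function(iobj)) {
    has_torch_func = true;
    break;
  }
}
Py_END_CRITICAL_SECTION();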
Ok!
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    PyObject* item_ptr =
        is_tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
Same thread safety issue
  auto* obj = PyTuple_GetItem(index_tup.ptr(), i);
- is_tensor_and_append_overloaded(obj, &overridable_args);
+ auto r = is_tensor_and_append_overloaded(obj, &overridable_args);
+ if (!r && PySequence_Check(obj)) {
Are we guaranteed we can't get a Tensor in here? :)
Seems pretty guaranteed to me?
bool is_tensor_and_append_overloaded(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args) {
  if (THPVariable_CheckExact(obj)) {
    // torch.Tensor instances (not subclasses, except for Parameter)
    return true;
  }
  if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
    // tensor subclasses and unrelated objects with __torch_function__
    append_overloaded_tensor(overloaded_args, obj);
    return true;
  } else if (THPVariable_Check(obj)) {
    // tensor subclasses without __torch_function__
    return true;
  }
  return false;
}
This is waiting for final approval!
Ok!
Let's mark this as BC-breaking so we can nicely track it in the release notes.
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
  const bool is_tuple = PyTuple_Check(obj);
  const auto size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
perf nitpick: if we're doing this optimization here we should've hoisted it up to lines 965-968 so we don't hit PySequence_Size. Also should only call PyTuple_Check once total.
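Roughly the shape of the change the nitpick describes (a sketch; the actual follow-up landed as #161998):

// Branch on the container type once, then reuse the result for both the
// size query and the per-element access (avoids PySequence_Size and a
// second PyTuple_Check).
const bool is_tuple = PyTuple_Check(obj);
if (!is_tuple && !PyList_Check(obj)) {
  return false;
}
const Py_ssize_t size =
    is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);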
sent #161998
This function has come up in DTensor perf work, and I had a nitpick on #160256 so here it is. I have neither compiled nor measured this, but am reasonably confident it's better nonetheless. [ghstack-poisoned]
This function has come up in DTensor perf work, and I had a nitpick on #160256 so here it is. I have neither compiled nor measured this, but am reasonably confident it's better nonetheless. Pull Request resolved: #161998 Approved by: https://github.com/ezyang
This was done exclusively with claude code and I haven't reviewed it yet Signed-off-by: Edward Yang <ezyang@meta.com> ghstack-source-id: f48741d Pull-Request: pytorch#160256
We basically follow the same pattern we do for tensor arguments. The major downside is we now have to traverse the entirety of the int list / etc where previously we didn't have. Benchmark suggests 2% regression for relevant things. Signed-off-by: Edward Yang <ezyang@meta.com> Pull Request resolved: pytorch#160256 Approved by: https://github.com/albanD
Stack from ghstack (oldest at bottom):
We basically follow the same pattern we do for tensor arguments. The major downside is that we now have to traverse the entirety of the int list (etc.), where previously we didn't have to. Benchmarking suggests a 2% regression for the relevant calls.
Signed-off-by: Edward Yang ezyang@meta.com
cc @gchanan