-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[inductor cpp] vectorize embedding lookup #114062
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114062
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 654d076 with merge base c77a4a4 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
For embedding lookup, there are indirect indexing with indices that are invariant to the vectorized itervar. To vectorize it, we need to keep the related indexing variables as scalars and allow vectorization when the related index_exprs are invariant to the vectorized itervar.
This PR adds the support by lazily broadcasting scalar values (index_expr and constant) to vectors so that vector operations are only generated if needed by `CppVecKernel` when any of the inputs are vectors, otherwise, scalar ops are generated. The cse variable in cpp is now represented with `CppCSEVariable` which bookkeeps the relevant itervars to the variable and has a flag to mark whether it is a scalar or a vector. `CppVecOverrides` is improved to propagate these states when the ops are executed.
For the added UT `test_embedding_vec`, the generated code before this PR is:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = in_ptr2[static_cast<long>(x1 + (128L*x0))];
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = in_ptr1[static_cast<long>(x1 + (128L*tmp3))];
auto tmp6 = decltype(tmp4)(tmp4 + tmp5);
out_ptr0[static_cast<long>(x1 + (128L*x0))] = tmp6;
}
}
}
}
}
```
After this PR, we have:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(16L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x1 + (128L*x0)));
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (128L*tmp3)));
auto tmp6 = tmp4 + tmp5;
tmp6.store(out_ptr0 + static_cast<long>(x1 + (128L*x0)));
}
}
}
}
}
```
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler
[ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I first saw this I figured out that the way I'd implement this is by:
- Having CppCSEVars track what loops do they depend on (you already do this)
- Then generalise Vec checker so that it returns a list of which loops can be vectorised
- Change the vectorised code to work on "per iterloop case". This you have already implemented via that wrapper trick (quite neat)
I am struggling to find where is it's the equivalent of point 2 above, that is, where do you figure out whether an iterloop can be vectorised.
torch/_inductor/codegen/cpp.py
Outdated
| V.kernel.cse.varname_map[s.name].relevant_itervars | ||
| ) | ||
|
|
||
| def is_relevant(self, itervar: sympy.Symbol): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
perhaps you could call this method depends_on?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Changed.
I am struggling to find where is it's the equivalent of point 2 above, that is, where do you figure out whether an iterloop can be vectorised.
Please check the function select_tiling_indices which gets the candidates of itervars that can be vectorized according to their contiguity. And then CppVecKernelChecker is given these candidates to see whether we have problems vectorizing them.
For embedding lookup, there are indirect indexing with indices that are invariant to the vectorized itervar. To vectorize it, we need to keep the related indexing variables as scalars and allow vectorization when the related index_exprs are invariant to the vectorized itervar.
This PR adds the support by lazily broadcasting scalar values (index_expr and constant) to vectors so that vector operations are only generated if needed by `CppVecKernel` when any of the inputs are vectors, otherwise, scalar ops are generated. The cse variable in cpp is now represented with `CppCSEVariable` which bookkeeps the relevant itervars to the variable and has a flag to mark whether it is a scalar or a vector. `CppVecOverrides` is improved to propagate these states when the ops are executed.
For the added UT `test_embedding_vec`, the generated code before this PR is:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = in_ptr2[static_cast<long>(x1 + (128L*x0))];
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = in_ptr1[static_cast<long>(x1 + (128L*tmp3))];
auto tmp6 = decltype(tmp4)(tmp4 + tmp5);
out_ptr0[static_cast<long>(x1 + (128L*x0))] = tmp6;
}
}
}
}
}
```
After this PR, we have:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(16L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x1 + (128L*x0)));
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (128L*tmp3)));
auto tmp6 = tmp4 + tmp5;
tmp6.store(out_ptr0 + static_cast<long>(x1 + (128L*x0)));
}
}
}
}
}
```
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler
[ghstack-poisoned]
|
@pytorchbot merge |
|
@pytorchbot revert -m 'Sorry for reverting your change, please help fix lint and reland it https://hud.pytorch.org/pytorch/pytorch/commit/2c0474c02d3ac04a429504225d7f1a6536d3b9e6' -c landrace |
|
@pytorchbot successfully started a revert job. Check the current status here. |
|
@jgong5 your PR has been successfully reverted. |
This reverts commit 2c0474c. Reverted #114062 on behalf of https://github.com/huydhn due to Sorry for reverting your change, please help fix lint and reland it https://hud.pytorch.org/pytorch/pytorch/commit/2c0474c02d3ac04a429504225d7f1a6536d3b9e6 ([comment](#114062 (comment)))
|
@pytorchbot merge |
Merge failedReason: PR 114062 is out of sync with the corresponding revision 387412d on branch origin/gh/jgong5/32/orig that would be merged into main. This usually happens because there is a non ghstack change in the PR. Please sync them and try again (ex. make the changes on origin/gh/jgong5/32/orig and run ghstack). Details for Dev Infra teamRaised by workflow job |
For embedding lookup, there are indirect indexing with indices that are invariant to the vectorized itervar. To vectorize it, we need to keep the related indexing variables as scalars and allow vectorization when the related index_exprs are invariant to the vectorized itervar.
This PR adds the support by lazily broadcasting scalar values (index_expr and constant) to vectors so that vector operations are only generated if needed by `CppVecKernel` when any of the inputs are vectors, otherwise, scalar ops are generated. The cse variable in cpp is now represented with `CppCSEVariable` which bookkeeps the relevant itervars to the variable and has a flag to mark whether it is a scalar or a vector. `CppVecOverrides` is improved to propagate these states when the ops are executed.
For the added UT `test_embedding_vec`, the generated code before this PR is:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = in_ptr2[static_cast<long>(x1 + (128L*x0))];
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = in_ptr1[static_cast<long>(x1 + (128L*tmp3))];
auto tmp6 = decltype(tmp4)(tmp4 + tmp5);
out_ptr0[static_cast<long>(x1 + (128L*x0))] = tmp6;
}
}
}
}
}
```
After this PR, we have:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(16L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x1 + (128L*x0)));
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (128L*tmp3)));
auto tmp6 = tmp4 + tmp5;
tmp6.store(out_ptr0 + static_cast<long>(x1 + (128L*x0)));
}
}
}
}
}
```
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler
[ghstack-poisoned]
For embedding lookup, there are indirect indexing with indices that are invariant to the vectorized itervar. To vectorize it, we need to keep the related indexing variables as scalars and allow vectorization when the related index_exprs are invariant to the vectorized itervar.
This PR adds the support by lazily broadcasting scalar values (index_expr and constant) to vectors so that vector operations are only generated if needed by `CppVecKernel` when any of the inputs are vectors, otherwise, scalar ops are generated. The cse variable in cpp is now represented with `CppCSEVariable` which bookkeeps the relevant itervars to the variable and has a flag to mark whether it is a scalar or a vector. `CppVecOverrides` is improved to propagate these states when the ops are executed.
For the added UT `test_embedding_vec`, the generated code before this PR is:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = in_ptr2[static_cast<long>(x1 + (128L*x0))];
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = in_ptr1[static_cast<long>(x1 + (128L*tmp3))];
auto tmp6 = decltype(tmp4)(tmp4 + tmp5);
out_ptr0[static_cast<long>(x1 + (128L*x0))] = tmp6;
}
}
}
}
}
```
After this PR, we have:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(16L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x1 + (128L*x0)));
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (128L*tmp3)));
auto tmp6 = tmp4 + tmp5;
tmp6.store(out_ptr0 + static_cast<long>(x1 + (128L*x0)));
}
}
}
}
}
```
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler
[ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
For embedding lookup, there are indirect indexing with indices that are invariant to the vectorized itervar. To vectorize it, we need to keep the related indexing variables as scalars and allow vectorization when the related index_exprs are invariant to the vectorized itervar.
This PR adds the support by lazily broadcasting scalar values (index_expr and constant) to vectors so that vector operations are only generated if needed by
CppVecKernelwhen any of the inputs are vectors, otherwise, scalar ops are generated. The cse variable in cpp is now represented withCppCSEVariablewhich bookkeeps the relevant itervars to the variable and has a flag to mark whether it is a scalar or a vector.CppVecOverridesis improved to propagate these states when the ops are executed.For the added UT
test_embedding_vec, the generated code before this PR is:After this PR, we have:
cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler