[inductor] add decompositions for aten.angle by ydwu4 · Pull Request #105609 · pytorch/pytorch · GitHub

Conversation

@ydwu4
Contributor

@ydwu4 ydwu4 commented Jul 19, 2023

Fixes #105564.

Added tests.
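
For reference, the real-input decomposition traced below amounts to roughly the following (a simplified Python sketch of the idea, not the exact code registered in inductor; the complex path instead lowers to `atan2(x.imag, x.real)`, as shown in the generated kernel further down):

```python
import math
import torch

def angle_decomp_real(x: torch.Tensor) -> torch.Tensor:
    # angle(x) for real x: 0 for x >= 0, pi for x < 0, NaN propagated.
    ret = torch.where(x < 0, math.pi, 0.0)
    nan_fill = torch.where(torch.isnan(x), float("nan"), 0.0)
    return ret + nan_fill
```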

CPU benchmarking result:
Before decomposition:

[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 0 =====
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.4 from /home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py:477 in wrapped class <lambda>(torch.nn.Module):
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, arg0_1: f32[100000]):
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/yidi/local/t.py:5, code: return torch.angle(x)
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         angle: f32[100000] = torch.ops.aten.angle.default(arg0_1);  arg0_1 = None
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return (angle,)
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] 
eager:
per-call time (us): 1069.2930221557617
compiled:
per-call time (us): 742.4068450927734

After decomposition:

[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 0 =====
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.4 from /home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py:477 in wrapped class <lambda>(torch.nn.Module):
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, arg0_1: f32[100000]):
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/yidi/local/t.py:5, code: return torch.angle(x)
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         lt: b8[100000] = torch.ops.aten.lt.Scalar(arg0_1, 0)
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(3.141592653589793, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         where: f32[100000] = torch.ops.aten.where.self(lt, scalar_tensor_1, scalar_tensor);  lt = scalar_tensor_1 = scalar_tensor = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         isnan: b8[100000] = torch.ops.aten.isnan.default(arg0_1);  arg0_1 = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor_2: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor_3: f32[] = torch.ops.aten.scalar_tensor.default(nan, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         where_1: f32[100000] = torch.ops.aten.where.self(isnan, scalar_tensor_3, scalar_tensor_2);  isnan = scalar_tensor_3 = scalar_tensor_2 = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         add: f32[100000] = torch.ops.aten.add.Tensor(where, where_1);  where = where_1 = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return (add,)
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] 
eager:
per-call time (us): 1228.0082702636719
compiled:
per-call time (us): 83.6038589477539
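
A minimal harness along these lines can reproduce comparable eager-vs-compiled measurements (illustrative only; this is not the exact script used for the numbers above, and results depend on hardware):

```python
import time
import torch

def bench(fn, x, iters=1000):
    # Warm up, then report the average per-call time in microseconds.
    for _ in range(10):
        fn(x)
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    return (time.perf_counter() - start) / iters * 1e6

x = torch.randn(100000)
eager = lambda t: torch.angle(t)
compiled = torch.compile(eager)
print("eager per-call time (us):", bench(eager, x))
print("compiled per-call time (us):", bench(compiled, x))
```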

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov

@pytorch-bot

pytorch-bot bot commented Jul 19, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/105609

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bed8776:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

# if x >= 0, return 0
# if x < 0, return pi
# if x is nan, return nan
ret = torch.where(x.real < 0, math.pi, 0.0)

Suggested change:
-ret = torch.where(x.real < 0, math.pi, 0.0)
+ret = torch.where(x < 0, math.pi, 0.0)
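
The two forms agree for real-valued inputs, since `.real` returns the input unchanged for real dtypes; the suggestion simply drops the redundant `.real`. A quick illustration (not code from the PR):

```python
import math
import torch

x = torch.tensor([-2.0, 0.0, 3.0])
assert torch.equal(x.real, x)  # .real is a no-op for real-valued tensors
assert torch.equal(
    torch.where(x.real < 0, math.pi, 0.0),
    torch.where(x < 0, math.pi, 0.0),
)
```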

@jansel
Contributor

jansel commented Jul 19, 2023

Can you show an example of the output code? For complex numbers we might just end up using fallbacks.

@ydwu4
Contributor Author

ydwu4 commented Jul 19, 2023

Can you show an example of the output code? For complex numbers we might just end up using fallbacks.

Sure! For the code below:

import torch
@torch.compile
def f(x):
    return torch.angle(x)

x = torch.tensor([-1, 1, 0, float("inf"), 1j]) # 1j is complex.
f(x)

The output is shown below. Is it a fallback? It seems to be handled properly, at least on CPU.

(pytorch-3.10) [yidi@devgpu018.ftw1 ~/local/pytorch]$ TORCH_LOGS="output_code" python t.py 
/home/yidi/local/pytorch/torch/_inductor/lowering.py:1296: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
[2023-07-19 16:30:57,985] torch._inductor.graph.__output_code: [INFO] Output code written to: /tmp/torchinductor_yidi/sg/csgfl2ypssdblsuaiktro237jlw3co3hdrzz35uqvdyepvyhfp35.py
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] Output code: 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] from ctypes import c_void_p, c_long
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] import torch
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] import math
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] import random
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] import os
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] import tempfile
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] from math import inf, nan
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.hooks import run_intermediate_hooks
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import maybe_profile
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] from torch import empty_strided, as_strided, device
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.codecache import AsyncCompile
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.select_algorithm import extern_kernels
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] aten = torch.ops.aten
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] async_compile = AsyncCompile()
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] cpp_fused_angle_0 = async_compile.cpp('''
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] #include "/tmp/torchinductor_yidi/zr/czrrhd67iy62iqdam5uwroq4ibq3i5oo4yzl6euetoa7k25vfk35.h"
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] extern "C" void kernel(const float* in_ptr0,
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]                        const float* in_ptr1,
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]                        const float* in_ptr2,
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]                        float* out_ptr0)
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] {
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     {
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]         #pragma GCC ivdep
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]         for(long i0=static_cast<long>(0L); i0<static_cast<long>(5L); i0+=static_cast<long>(1L))
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]         {
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]             auto tmp0 = in_ptr0[static_cast<long>(2L*i0)];
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]             auto tmp2 = in_ptr1[static_cast<long>(1L + (2L*i0))];
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]             auto tmp3 = in_ptr2[static_cast<long>(2L*i0)];
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]             auto tmp1 = std::isnan(tmp0);
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]             auto tmp4 = std::atan2(tmp2, tmp3);
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]             auto tmp5 = std::numeric_limits<float>::quiet_NaN();
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]             auto tmp6 = tmp1 ? tmp5 : tmp4;
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]             out_ptr0[static_cast<long>(i0)] = tmp6;
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]         }
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     }
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] }
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] ''')
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] async_compile.wait(globals())
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] del async_compile
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] def call(args):
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     arg0_1, = args
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     args.clear()
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     assert_size_stride(arg0_1, (5, ), (1, ))
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     buf0 = aten.view_as_real.default(arg0_1)
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     buf1 = buf0
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     assert_size_stride(buf1, (5, 2), (2, 1))
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     del buf0
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     buf2 = aten.view_as_real.default(arg0_1)
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     buf3 = buf2
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     assert_size_stride(buf3, (5, 2), (2, 1))
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     del buf2
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     buf4 = aten.view_as_real.default(arg0_1)
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     del arg0_1
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     buf5 = buf4
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     assert_size_stride(buf5, (5, 2), (2, 1))
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     del buf4
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     buf6 = empty_strided((5, ), (1, ), device='cpu', dtype=torch.float32)
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     cpp_fused_angle_0(c_void_p(buf1.data_ptr()), c_void_p(buf3.data_ptr()), c_void_p(buf5.data_ptr()), c_void_p(buf6.data_ptr()))
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     return (buf6, )
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] def benchmark_compiled_module(times=10, repeat=10):
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     from torch._dynamo.testing import rand_strided
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     from torch._inductor.utils import print_performance
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     arg0_1 = rand_strided((5, ), (1, ), device='cpu', dtype=torch.complex64)
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     return print_performance(lambda: call([arg0_1]), times=times, repeat=repeat)
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] if __name__ == "__main__":
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     from torch._inductor.utils import compiled_module_main
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG]     compiled_module_main('None', benchmark_compiled_module)
[2023-07-19 16:30:57,986] torch._inductor.graph.__output_code: [DEBUG] 
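
For the complex path, the generated kernel above effectively computes the following per element (a hedged Python restatement of the C++ kernel, not code from the PR):

```python
import torch

def angle_complex_ref(z: torch.Tensor) -> torch.Tensor:
    # Mirrors the generated kernel: atan2(imag, real), with a NaN in the
    # real part propagated to the result.
    zr = torch.view_as_real(z)            # shape (..., 2): [real, imag]
    real, imag = zr[..., 0], zr[..., 1]
    out = torch.atan2(imag, real)
    return torch.where(torch.isnan(real), torch.full_like(out, float("nan")), out)
```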

@jansel
Contributor

jansel commented Jul 19, 2023

Looks good, thanks!

@ydwu4
Contributor Author

ydwu4 commented Jul 20, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Jul 20, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 24, 2023
Tracks #98161

Complex number support in PyTorch isn't ideal today: complex operations mostly end up handled by the ATen runtime, except for `torch.angle`, which is handled in [105609](#105609). A better general approach is to decompose complex operations first, exposing more opportunities for fusion, and then have Triton handle non-contiguous (strided) tensor operations more efficiently. This change adds support for decomposing complex additions.

```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)
```
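
Conceptually, the decomposition treats each complex tensor as an interleaved real/imaginary view and adds the views elementwise, roughly as in this illustrative sketch (not the code in the PR):

```python
import torch

def add_complex_via_decomp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # View each complex tensor as a real tensor with a trailing dim of 2
    # (real, imag), add elementwise, then reinterpret the result as complex.
    return torch.view_as_complex(torch.view_as_real(a) + torch.view_as_real(b))
```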

Pull Request resolved: #110740
Approved by: https://github.com/jansel
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023