[aotd] Fix rrelu compilation by IvanKobzarev · Pull Request #136008 · pytorch/pytorch · GitHub

Conversation

@IvanKobzarev (Contributor) commented Sep 13, 2024

Stack from ghstack (oldest at bottom):

Issues:
#135083
#120292

The rrelu decomposition contains a mutation (copy_). Decompositions are executed below Functionalization, so AOT tracing produces a non-functional graph.

That decomposition is also registered as a python_dispatch kernel for AutogradCUDA only. Autograd dispatch happens above Functionalization, so registering it for Autograd (covering all backends) makes functionalization run after it.

Testing:
```
python test/functorch/test_aotdispatch.py -k test_rrelu
```
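A minimal sketch of the dispatch-key point above, using a hypothetical `demo::my_rrelu` op rather than the actual PR diff: a Python kernel registered for AutogradCUDA only intercepts CUDA autograd dispatch, while a kernel registered for Autograd covers all backends and runs above Functionalization, so the copy_ inside the decomposition can be functionalized during tracing.

```
import torch
from torch.library import Library

# Hypothetical namespace/op, for illustration only (not the real aten registration).
lib = Library("demo", "DEF")
lib.define("my_rrelu(Tensor x, Tensor noise) -> Tensor")

def my_rrelu_decomp(x, noise):
    # A decomposition containing an in-place copy_, like the rrelu decomposition.
    u = torch.empty_like(x).uniform_(0.2, 0.8)
    noise.copy_(torch.where(x <= 0, u, torch.ones_like(x)))
    return torch.where(x <= 0, x * u, x)

# Before the fix (conceptually): only CUDA autograd dispatch hits the Python kernel.
# lib.impl("my_rrelu", my_rrelu_decomp, "AutogradCUDA")

# After the fix (conceptually): register for Autograd, which covers all backends
# and sits above Functionalization, so the copy_ gets functionalized when tracing.
lib.impl("my_rrelu", my_rrelu_decomp, "Autograd")
```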

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
@pytorch-bot bot commented Sep 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136008

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7201db2 with merge base efed357:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

IvanKobzarev added a commit that referenced this pull request Sep 13, 2024
ghstack-source-id: c19f725
Pull Request resolved: #136008
@IvanKobzarev added the `topic: not user facing` label Sep 13, 2024
IvanKobzarev added a commit that referenced this pull request Sep 16, 2024
ghstack-source-id: 3ee5c50
Pull Request resolved: #136008

```
is_factory_function: bool = False

is_randomized_result: bool = False
```

A Contributor commented on these lines:

Can you say more about why we need this flag? The main reasons I'm surprised are:

(1) most ops that involve randomness can be made deterministic by setting various pytorch state, e.g. torch.manual_seed (doc). I think we can probably figure out how to reset the state between running the ref/test so the outputs match? (also, we have other rng ops that don't seem to need a special flag in the OpInfo tests today)

(2) inductor does generate randomness differently from eager mode, but the tests you're updating aren't actually using inductor (they use AOTDispatcher, aka they are capturing a graph of the same underlying aten ops and replaying them)
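A minimal sketch of the reseeding idea in (1), illustration only and not part of the existing OpInfo machinery: resetting the RNG state with torch.manual_seed before the reference run and before the test run makes a randomized op produce identical outputs, so the two can be compared directly.

```
import torch

def run_seeded(fn, *args, seed=123):
    torch.manual_seed(seed)  # reset RNG state so randomized ops are reproducible
    return fn(*args)

x = torch.randn(4, 4)
ref = run_seeded(torch.rand_like, x)
test = run_seeded(torch.rand_like, x)
assert torch.equal(ref, test)  # same seed -> same random values
```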

IvanKobzarev added a commit that referenced this pull request Sep 17, 2024
ghstack-source-id: c2417ff
Pull Request resolved: #136008
IvanKobzarev added a commit that referenced this pull request Sep 17, 2024
ghstack-source-id: 89e87f5
Pull Request resolved: #136008
IvanKobzarev added a commit that referenced this pull request Sep 23, 2024
ghstack-source-id: 4c985da
Pull Request resolved: #136008
@IvanKobzarev (Contributor, Author) commented Sep 23, 2024

@pytorchbot merge


@pytorch-bot bot added the `ciflow/trunk` label (Trigger trunk jobs on your pull request) Sep 23, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: Command `git -C /home/runner/work/pytorch/pytorch rebase origin/main` returned non-zero exit code 1

```
Rebasing (1/1)
Auto-merging test/functorch/test_aotdispatch.py
CONFLICT (content): Merge conflict in test/functorch/test_aotdispatch.py
error: could not apply 1047213a0f... [aotd] Fix rrelu compilation (#136008)
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config advice.mergeConflict false"
Could not apply 1047213a0f... [aotd] Fix rrelu compilation (#136008)
```

Details for Dev Infra team: raised by workflow job

@IvanKobzarev (Contributor, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: Command `git -C /home/runner/work/pytorch/pytorch cherry-pick -x 499c36d95d1f268ada73c57e9888befe14d973d6` returned non-zero exit code 1

```
Auto-merging test/functorch/test_aotdispatch.py
CONFLICT (content): Merge conflict in test/functorch/test_aotdispatch.py
error: could not apply 499c36d95d... [aotd] Fix rrelu compilation
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config advice.mergeConflict false"
```

Details for Dev Infra team: raised by workflow job

IvanKobzarev added a commit that referenced this pull request Sep 24, 2024
ghstack-source-id: 412e9ba
Pull Request resolved: #136008
@junlin-habana

When I used this updated code I found another problem: the same code in eager mode does not give the same result as the inductor-compiled version.

```
import torch

def fn(x):
    empty = torch.ones_like(x, memory_format=torch.contiguous_format)
    result = torch._C._nn.rrelu_with_noise(x, empty, 0.2, 0.8, True)
    return (result, empty)

x = torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True)
compiled_fn = torch.compile(fn, backend="inductor")

res1, res2 = compiled_fn(x)
res3, res4 = fn(x)
print(res2)
print(res4)
```

output:

```
res2: tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], dtype=torch.bfloat16)
res4: tensor([[1.0000, 0.2490, 1.0000, 1.0000],
        [1.0000, 0.7852, 1.0000, 0.7617],
        [1.0000, 1.0000, 0.4102, 0.2578],
        [0.6328, 1.0000, 1.0000, 1.0000]], dtype=torch.bfloat16)
```
rrelu_with_noise changes the value of noise in eager mode, but not under inductor compilation.

@IvanKobzarev (Contributor, Author) commented Sep 25, 2024

@junlin-habana Thanks for the testing.

Yes, there is an additional problem: the noise argument is not marked as a mutated argument in the schema of rrelu, so this mutation is not auto-functionalized. We will fix it in a follow-up diff changing the schema for all rrelu ops.
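A minimal sketch of what such a schema change means, using a hypothetical `demo_schema::rrelu_like` op rather than the real aten schema: the `Tensor(a!)` annotation declares an argument as aliased-and-mutated, which is what lets functionalization see and handle the mutation.

```
import torch
from torch.library import Library

# Hypothetical namespace/op for illustration; not the actual aten schema change.
lib = Library("demo_schema", "DEF")

# A plain "Tensor noise" hides the mutation; "Tensor(a!) noise" declares it.
lib.define(
    "rrelu_like(Tensor self, Tensor(a!) noise, float lower, float upper) -> Tensor"
)
```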

@IvanKobzarev (Contributor, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@IvanKobzarev (Contributor, Author)

@junlin-habana I rechecked your test:

```
    def test_rrelu_noise_mutation(self):
        def fn(x):
            noise = torch.ones_like(x)
            result = torch._C._nn.rrelu_with_noise(x, noise, 0.2, 0.8, True)
            return result, noise

        x = -torch.abs(torch.randn(4, 4, requires_grad=True))

        ref_y, ref_noise = fn(x)
        print(f"XXX ref_noise")
        print(ref_noise)

        compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True)
        comp_y, comp_noise = compiled_fn(x)
        print(f"XXX comp_noise")
        print(comp_noise)
```

And noise mutation is captured:

```
XXX ref_noise
tensor([[0.3616, 0.7460, 0.6859, 0.2228],
        [0.6393, 0.5682, 0.3641, 0.3979],
        [0.3886, 0.3187, 0.6494, 0.5895],
        [0.3719, 0.3988, 0.4285, 0.4109]])
XXX comp_noise
tensor([[0.7303, 0.6850, 0.6517, 0.7393],
        [0.6103, 0.6595, 0.7490, 0.4396],
        [0.2660, 0.3525, 0.4600, 0.4670],
        [0.4980, 0.6719, 0.5962, 0.2782]], grad_fn=<CompiledFunctionBackward>)
```

And the AOT graph also captures the mutation:

```
  def forward(self, primals_1: "f32[4, 4][4, 1]cpu"):
       # File: /data/users/ivankobzarev/b/pytorch/test/functorch/test_aotdispatch.py:6008 in fn, code: result = torch._C._nn.rrelu_with_noise(x, noise, 0.2, 0.8, True)
      le: "b8[4, 4][4, 1]cpu" = torch.ops.aten.le.Scalar(primals_1, 0)
      uniform: "f32[4, 4][4, 1]cpu" = torch.ops.aten.uniform.default(primals_1, 0.2, 0.8)
      mul: "f32[4, 4][4, 1]cpu" = torch.ops.aten.mul.Tensor(primals_1, uniform)
      where: "f32[4, 4][4, 1]cpu" = torch.ops.aten.where.self(le, mul, primals_1);  mul = primals_1 = None
      full_default_1: "f32[][]cpu" = torch.ops.aten.full.default([], 1.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
      where_1: "f32[4, 4][4, 1]cpu" = torch.ops.aten.where.self(le, uniform, full_default_1);  full_default_1 = None
      return (where, where_1, le, uniform)
```

Could you please recheck on the latest master?

BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this pull request Sep 25, 2024
Pull Request resolved: pytorch#136008
Approved by: https://github.com/bdhirsh
@junlin-habana

> @junlin-habana I rechecked your test: […] Could you please recheck on the latest master?

Ok, thanks for your help, looks good on the latest master.

@LJ-underdog (Contributor)

> @junlin-habana I rechecked your test: […] Could you please recheck on the latest master?

@IvanKobzarev
Can you change the dtype of the input x to bfloat16? Like the following, in which case I see that comp_noise is not changed:

```
x = -torch.abs(torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True))
```

@IvanKobzarev (Contributor, Author)

@intellinjun Thanks, reproduced. With bf16 it is not mutating noise; something with the dtype conversion in compilation. Will debug it.

@IvanKobzarev (Contributor, Author)

@intellinjun #136784 for tracking as a separate issue

@github-actions github-actions bot deleted the gh/IvanKobzarev/69/head branch October 27, 2024 02:10