KEMBAR78
[dynamo] fine-grained bytecode-source attribution in python 3.11 by williamwen42 · Pull Request #104676 · pytorch/pytorch · GitHub
Skip to content

Conversation

@williamwen42
Copy link
Member

@williamwen42 williamwen42 commented Jul 6, 2023

Stack from ghstack (oldest at bottom):

Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

Sample:

import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def h(x):
    return x * 5

def true_fn(x):
    return x * 2

def false_fn(x):
    return x * 3

def f(pred, x):
    x = h(
        h(h(x))
    )
    x = x[1:][:2]
    torch._dynamo.graph_break()
    x = cond(pred, true_fn, false_fn, [x])

opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))

Output:

$ TORCH_LOGS="trace_call" python playground9.py 
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
          ~^^^
TRACE FX call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
        ~^^^^^^
TRACE FX call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
    x = h(
        ~^
        h(h(x))
        ^^^^^^^
    )
    ^
TRACE FX call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE FX call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~^^^^
TRACE FX call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
    return x * 2
           ~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
    return x * 3
           ~~^~~
TRACE FX call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @anijain2305 @Xia-Weiwen @ipiszy @aakhundov

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104676

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 84b4847:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
williamwen42 added a commit that referenced this pull request Jul 6, 2023
ghstack-source-id: 15aeb89
Pull Request resolved: #104676
@williamwen42 williamwen42 changed the title [dynamo] better trace_source logging in python 3.11 [dynamo] fine-grained bytecode-source attribution in python 3.11 Jul 7, 2023
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
williamwen42 added a commit that referenced this pull request Jul 8, 2023
ghstack-source-id: 908c3ee
Pull Request resolved: #104676
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def h(x):
    return x * 5

def true_fn(x):
    return x * 2

def false_fn(x):
    return x * 3

def f(pred, x):
    x = h(
        h(h(x))
    )
    x = x[1:][:2]
    torch._dynamo.graph_break()
    x = cond(pred, true_fn, false_fn, [x])

opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```

Output:
```shell
$ TORCH_LOGS="trace_call" python playground9.py 
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
          ^^^^
TRACE call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
        ^^^^^^^
TRACE call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
    x = h(
        ^^
        h(h(x))
        ^^^^^^^
    )
    ^
TRACE call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~^^^^
TRACE call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
    return x * 2
           ~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
    return x * 3
           ~~^~~
TRACE call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def h(x):
    return x * 5

def true_fn(x):
    return x * 2

def false_fn(x):
    return x * 3

def f(pred, x):
    x = h(
        h(h(x))
    )
    x = x[1:][:2]
    torch._dynamo.graph_break()
    x = cond(pred, true_fn, false_fn, [x])

opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```

Output:
```shell
$ TORCH_LOGS="trace_call" python playground9.py 
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
          ^^^^
TRACE call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
        ^^^^^^^
TRACE call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
    x = h(
        ^^
        h(h(x))
        ^^^^^^^
    )
    ^
TRACE call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~^^^^
TRACE call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
    return x * 2
           ~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
    return x * 3
           ~~^~~
TRACE call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def h(x):
    return x * 5

def true_fn(x):
    return x * 2

def false_fn(x):
    return x * 3

def f(pred, x):
    x = h(
        h(h(x))
    )
    x = x[1:][:2]
    torch._dynamo.graph_break()
    x = cond(pred, true_fn, false_fn, [x])

opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```

Output:
```shell
$ TORCH_LOGS="trace_call" python playground9.py 
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
          ^^^^
TRACE call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
        ^^^^^^^
TRACE call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
    x = h(
        ^^
        h(h(x))
        ^^^^^^^
    )
    ^
TRACE call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~^^^^
TRACE call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
    return x * 2
           ~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
    return x * 3
           ~~^~~
TRACE call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def h(x):
    return x * 5

def true_fn(x):
    return x * 2

def false_fn(x):
    return x * 3

def f(pred, x):
    x = h(
        h(h(x))
    )
    x = x[1:][:2]
    torch._dynamo.graph_break()
    x = cond(pred, true_fn, false_fn, [x])

opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```

Output:
```shell
$ TORCH_LOGS="trace_call" python playground9.py 
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
          ^^^^
TRACE call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
        ^^^^^^^
TRACE call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
    x = h(
        ^^
        h(h(x))
        ^^^^^^^
    )
    ^
TRACE call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~^^^^
TRACE call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
    return x * 2
           ~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
    return x * 3
           ~~^~~
TRACE call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def h(x):
    return x * 5

def true_fn(x):
    return x * 2

def false_fn(x):
    return x * 3

def f(pred, x):
    x = h(
        h(h(x))
    )
    x = x[1:][:2]
    torch._dynamo.graph_break()
    x = cond(pred, true_fn, false_fn, [x])

opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```

Output:
```shell
$ TORCH_LOGS="trace_call" python playground9.py 
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
          ^^^^
TRACE call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
        ^^^^^^^
TRACE call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
    x = h(
        ^^
        h(h(x))
        ^^^^^^^
    )
    ^
TRACE call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~^^^^
TRACE call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
    return x * 2
           ~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
    return x * 3
           ~~^~~
TRACE call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
williamwen42 added a commit that referenced this pull request Jul 18, 2023
ghstack-source-id: dec45d6
Pull Request resolved: #104676
…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def h(x):
    return x * 5

def true_fn(x):
    return x * 2

def false_fn(x):
    return x * 3

def f(pred, x):
    x = h(
        h(h(x))
    )
    x = x[1:][:2]
    torch._dynamo.graph_break()
    x = cond(pred, true_fn, false_fn, [x])

opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```

Output:
```shell
$ TORCH_LOGS="trace_call" python playground9.py 
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
          ^^^^
TRACE call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
        ^^^^^^^
TRACE call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
    x = h(
        ^^
        h(h(x))
        ^^^^^^^
    )
    ^
TRACE call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~^^^^
TRACE call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
    return x * 2
           ~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
    return x * 3
           ~~^~~
TRACE call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

[ghstack-poisoned]
williamwen42 added a commit that referenced this pull request Jul 19, 2023
ghstack-source-id: f9d9a02
Pull Request resolved: #104676
@williamwen42 williamwen42 added ciflow/trunk Trigger trunk jobs on your pull request topic: not user facing topic category labels Jul 19, 2023
@williamwen42
Copy link
Member Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x b592c975734ec51d3f44c42a450466c93315f9bf returned non-zero exit code 1

Auto-merging test/dynamo/test_logging.py
Auto-merging test/dynamo/test_misc.py
Auto-merging torch/_dynamo/output_graph.py
Auto-merging torch/_dynamo/resume_execution.py
Auto-merging torch/_dynamo/symbolic_convert.py
Auto-merging torch/_dynamo/utils.py
CONFLICT (content): Merge conflict in torch/_dynamo/utils.py
error: could not apply b592c975734... [dynamo] better trace_source logging in python 3.11
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
Details for Dev Infra team Raised by workflow job

…n 3.11"


Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def h(x):
    return x * 5

def true_fn(x):
    return x * 2

def false_fn(x):
    return x * 3

def f(pred, x):
    x = h(
        h(h(x))
    )
    x = x[1:][:2]
    torch._dynamo.graph_break()
    x = cond(pred, true_fn, false_fn, [x])

opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```

Output:
```shell
$ TORCH_LOGS="trace_call" python playground9.py 
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
          ^^^^
TRACE call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
        ^^^^^^^
TRACE call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
    x = h(
        ^^
        h(h(x))
        ^^^^^^^
    )
    ^
TRACE call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~^^^^
TRACE call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
    return x * 2
           ~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
    return x * 3
           ~~^~~
TRACE call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
williamwen42 added a commit that referenced this pull request Jul 20, 2023
ghstack-source-id: 3b742e8
Pull Request resolved: #104676
@williamwen42
Copy link
Member Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo module: python version Issues related to specific Python versions topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants