[PyTorch][Vulkan] Create context for layernorm by copyrightly · Pull Request #114701 · pytorch/pytorch · GitHub

Conversation

@copyrightly
Contributor

Summary:

`Layernorm` has two arguments, `weight` and `bias`, which are stored as constant tensors on the CPU and transferred to the GPU at every inference call. We create a context for this op to avoid the repeated transfer (a usage sketch follows the list below). Specifically, we:

- created `create_layernorm_context` and `run_layernorm_context` in `Layernorm.h` and `Layernorm.cpp`
- registered them in `Register.cpp`
- rewrote the graph representation of the op in `vulkan_rewrite.cpp`
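
To make the intended call pattern concrete, here is a minimal Python sketch of how the two ops fit together. The three-argument form of `run_layernorm_context` (input, normalized_shape, context) matches the rewritten graph shown below; the schema assumed for `create_layernorm_context` (weight, bias, eps) and the `torch.ops.vulkan_prepack` access path are assumptions, not code from this PR, and the snippet needs a Vulkan-enabled build.

```
import torch

# Assumed schemas (inferred, not authoritative):
#   vulkan_prepack::create_layernorm_context(weight, bias, eps) -> context
#   vulkan_prepack::run_layernorm_context(input, normalized_shape, context) -> Tensor
weight = torch.randn(10)
bias = torch.randn(10)

# Pack weight and bias once; the constants are transferred to the GPU here
# instead of on every inference call.
ctx = torch.ops.vulkan_prepack.create_layernorm_context(weight, bias, 1e-5)

x = torch.randn(1, 10).to("vulkan")  # move the input to the Vulkan backend
# Run the op repeatedly without re-transferring weight/bias.
y = torch.ops.vulkan_prepack.run_layernorm_context(x, [10], ctx)
out = y.cpu()
```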

Test Plan:

## Numerical test
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (b6ccc956c)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*layer_norm*"
Recommended: For faster builds try buck2: replace 'buck' with 'buck2'
NOTE: buck-out/ has changed: look for files in fbsource/buck-out/v2/
'buck2 build --show-output //xplat/caffe2:pt_vulkan_api_test_bin' will print the new output paths.


If you are building in fbsource//xplat and have questions, post in 'Cross Platform Dev Discussions': https://fb.workplace.com/groups/xplat.qa

  Targets matching .buckconfig buck2.supported_projects:
  {'//xplat/caffe2:pt_vulkan_api_test_bin': '//xplat'}

  To suppress this warning: touch ~/.config/.dont_hint_buck2

Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.2 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.14.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *layer_norm*
[==========] Running 10 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 10 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.packed_layer_norm_2d
[       OK ] VulkanAPITest.packed_layer_norm_2d (342 ms)
[ RUN      ] VulkanAPITest.packed_layer_norm_3d
[       OK ] VulkanAPITest.packed_layer_norm_3d (284 ms)
[ RUN      ] VulkanAPITest.packed_layer_norm_4d
[       OK ] VulkanAPITest.packed_layer_norm_4d (5 ms)
[ RUN      ] VulkanAPITest.layer_norm_invalid_inputs
[       OK ] VulkanAPITest.layer_norm_invalid_inputs (28 ms)
[ RUN      ] VulkanAPITest.layer_norm_2d
[       OK ] VulkanAPITest.layer_norm_2d (1 ms)
[ RUN      ] VulkanAPITest.layer_norm_3d
[       OK ] VulkanAPITest.layer_norm_3d (2 ms)
[ RUN      ] VulkanAPITest.layer_norm_4d
[       OK ] VulkanAPITest.layer_norm_4d (4 ms)
[ RUN      ] VulkanAPITest.native_layer_norm_2d
[       OK ] VulkanAPITest.native_layer_norm_2d (1 ms)
[ RUN      ] VulkanAPITest.native_layer_norm_3d
[       OK ] VulkanAPITest.native_layer_norm_3d (2 ms)
[ RUN      ] VulkanAPITest.native_layer_norm_4d
[       OK ] VulkanAPITest.native_layer_norm_4d (6 ms)
[----------] 10 tests from VulkanAPITest (679 ms total)

[----------] Global test environment tear-down
[==========] 10 tests from 1 test suite ran. (679 ms total)
[  PASSED  ] 10 tests.
```

Full test results are in P888496077; a summary is below:

```
[----------] 419 tests from VulkanAPITest (21652 ms total)

[----------] Global test environment tear-down
[==========] 419 tests from 1 test suite ran. (21652 ms total)
[  PASSED  ] 418 tests.
[  SKIPPED ] 1 test, listed below:
[  SKIPPED ] VulkanAPITest.querypool_flushed_shader_log
```

## Graph representation comparison

We created a model using `layer_norm` and traced it as shown below:

```
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=10)

    def forward(self, x):
        return self.layer_norm(x)

# Create an instance of the model
model = MyModel()

# Create a dummy input tensor for tracing
input_tensor = torch.randn(1, 10)

# Use torch.jit.trace to trace the model and generate a graph
traced_model = torch.jit.trace(model, input_tensor)
```

Then we converted the traced model to the Vulkan backend using `optimize_for_mobile`:

```
from torch.utils import mobile_optimizer

# Methods to preserve in addition to forward(); none are needed for this model.
to_preserve = []

vulkan_model = mobile_optimizer.optimize_for_mobile(
    traced_model, backend="vulkan", preserved_methods=to_preserve
)
```

Then we can print the graph of the `vulkan_model` with `print(vulkan_model.graph)`:

- Before this diff:

```
  %4 : bool = prim::Constant[value=1](), scope: __module.layer_norm # /mnt/xarfuse/uid-602118/33e18f68-seed-nspid4026531836_cgpid32066351-ns-4026531840/torch/nn/functional.py:2546:0
  %5 : float = prim::Constant[value=1.0000000000000001e-05](), scope: __module.layer_norm # /mnt/xarfuse/uid-602118/33e18f68-seed-nspid4026531836_cgpid32066351-ns-4026531840/torch/nn/functional.py:2546:0
  %14 : int[] = prim::Constant[value=[10]]()
  %33 : Tensor = aten::to(%x, %53, %30, %31, %31)
  %10 : Tensor = aten::layer_norm(%33, %14, %self.layer_norm.weight, %self.layer_norm.bias, %5, %4), scope: __module.layer_norm # /mnt/xarfuse/uid-602118/33e18f68-seed-nspid4026531836_cgpid32066351-ns-4026531840/torch/nn/functional.py:2546:0
```

- After this diff:

```
  %14 : int[] = prim::Constant[value=[10]]()
  %47 : Tensor = aten::to(%x, %78, %44, %45, %45)
  %16 : Tensor = vulkan_prepack::run_layernorm_context(%47, %14, %17)
```
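
As a quick sanity check (not part of this PR), the rewrite can also be verified programmatically by searching the graph text for the new op, and the converted model can be run once end to end. This sketch assumes the `vulkan_model` built above and, for the forward call, a Vulkan-capable runtime (e.g. SwiftShader as in the test plan).

```
# Confirm that aten::layer_norm was replaced by the prepacked context op.
graph_str = str(vulkan_model.graph)
assert "vulkan_prepack::run_layernorm_context" in graph_str
assert "aten::layer_norm" not in graph_str

# Optional smoke test: the rewritten graph handles the input transfer
# (the aten::to above), so a plain CPU tensor is passed in here.
out = vulkan_model(torch.randn(1, 10))
print(out.shape)
```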

Reviewed By: SS-JIA

Differential Revision: D51530478

pytorch-bot added the ciflow/periodic, module: vulkan, and release notes: vulkan labels on Nov 28, 2023
@pytorch-bot

pytorch-bot bot commented Nov 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114701

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f96df3b with merge base e891a3b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D51530478

copyrightly added a commit that referenced this pull request on Nov 29, 2023

@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

pytorch-bot added the ciflow/trunk label on Nov 30, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

facebook-github-bot deleted the export-D51530478 branch on December 3, 2023 15:26
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
Pull Request resolved: pytorch#114701
Approved by: https://github.com/yipjustin