KEMBAR78
[Pytorch][Vulkan] layer_norm by copyrightly · Pull Request #112322 · pytorch/pytorch · GitHub
Skip to content

Conversation

@copyrightly
Copy link
Contributor

Summary:
Generalize layer_norm to all tensors of 2d to 4d. Using the mean and var operators in this diff stack, we can compute the layer_norm directly and remove the old shader file layernorm.glsl.

(input - input.mean(normalized_shape, keepdim=True)) / torch.sqrt(input.var(normalized_shape, correction=0, keepdims = True) + eps) * weight + bias

Test Plan:

[luwei@devbig984.prn1 /data/users/luwei/fbsource (0a5028d8c)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*layer_norm*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *layer_norm*
[==========] Running 4 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 4 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.layer_norm_invalid_inputs
[       OK ] VulkanAPITest.layer_norm_invalid_inputs (69 ms)
[ RUN      ] VulkanAPITest.layer_norm_2d
[       OK ] VulkanAPITest.layer_norm_2d (288 ms)
[ RUN      ] VulkanAPITest.layer_norm_3d
[       OK ] VulkanAPITest.layer_norm_3d (302 ms)
[ RUN      ] VulkanAPITest.layer_norm_4d
[       OK ] VulkanAPITest.layer_norm_4d (8 ms)
[----------] 4 tests from VulkanAPITest (668 ms total)

[----------] Global test environment tear-down
[==========] 4 tests from 1 test suite ran. (668 ms total)
[  PASSED  ] 4 tests.

Reviewed By: yipjustin

Differential Revision: D50436726

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112322

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e2b1f14 with merge base eb8af4d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR module: vulkan release notes: vulkan release notes category labels Oct 28, 2023
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D50436726

Summary:

Generalize [layer_norm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) to all tensors of 2d to 4d. Using the mean and var operators in this diff stack, we can compute the layer_norm directly and remove the old shader file `layernorm.glsl`.
```
(input - input.mean(normalized_shape, keepdim=True)) / torch.sqrt(input.var(normalized_shape, correction=0, keepdims = True) + eps) * weight + bias
```

Test Plan:
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (0a5028d8c)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*layer_norm*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *layer_norm*
[==========] Running 4 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 4 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.layer_norm_invalid_inputs
[       OK ] VulkanAPITest.layer_norm_invalid_inputs (69 ms)
[ RUN      ] VulkanAPITest.layer_norm_2d
[       OK ] VulkanAPITest.layer_norm_2d (288 ms)
[ RUN      ] VulkanAPITest.layer_norm_3d
[       OK ] VulkanAPITest.layer_norm_3d (302 ms)
[ RUN      ] VulkanAPITest.layer_norm_4d
[       OK ] VulkanAPITest.layer_norm_4d (8 ms)
[----------] 4 tests from VulkanAPITest (668 ms total)

[----------] Global test environment tear-down
[==========] 4 tests from 1 test suite ran. (668 ms total)
[  PASSED  ] 4 tests.
```

Reviewed By: yipjustin

Differential Revision: D50436726
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D50436726

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D50436726


return convert(v_output);

std::vector<int64_t> dims_to_reduce;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::vector<int64_t> dims_to_reduce;
std::vector<int64_t> dims_to_reduce;
dims_to_reduce.reserve(normalized_shape.size());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @Skylion007! I will fix it in a subsequent PR.

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 30, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the export-D50436726 branch November 3, 2023 14:26
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Summary:
Generalize [layer_norm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) to all tensors of 2d to 4d. Using the mean and var operators in this diff stack, we can compute the layer_norm directly and remove the old shader file `layernorm.glsl`.
```
(input - input.mean(normalized_shape, keepdim=True)) / torch.sqrt(input.var(normalized_shape, correction=0, keepdims = True) + eps) * weight + bias
```

Test Plan:
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (0a5028d8c)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*layer_norm*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *layer_norm*
[==========] Running 4 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 4 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.layer_norm_invalid_inputs
[       OK ] VulkanAPITest.layer_norm_invalid_inputs (69 ms)
[ RUN      ] VulkanAPITest.layer_norm_2d
[       OK ] VulkanAPITest.layer_norm_2d (288 ms)
[ RUN      ] VulkanAPITest.layer_norm_3d
[       OK ] VulkanAPITest.layer_norm_3d (302 ms)
[ RUN      ] VulkanAPITest.layer_norm_4d
[       OK ] VulkanAPITest.layer_norm_4d (8 ms)
[----------] 4 tests from VulkanAPITest (668 ms total)

[----------] Global test environment tear-down
[==========] 4 tests from 1 test suite ran. (668 ms total)
[  PASSED  ] 4 tests.
```

Reviewed By: yipjustin

Differential Revision: D50436726

Pull Request resolved: pytorch#112322
Approved by: https://github.com/yipjustin
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Summary:
Generalize [layer_norm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) to all tensors of 2d to 4d. Using the mean and var operators in this diff stack, we can compute the layer_norm directly and remove the old shader file `layernorm.glsl`.
```
(input - input.mean(normalized_shape, keepdim=True)) / torch.sqrt(input.var(normalized_shape, correction=0, keepdims = True) + eps) * weight + bias
```

Test Plan:
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (0a5028d8c)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*layer_norm*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *layer_norm*
[==========] Running 4 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 4 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.layer_norm_invalid_inputs
[       OK ] VulkanAPITest.layer_norm_invalid_inputs (69 ms)
[ RUN      ] VulkanAPITest.layer_norm_2d
[       OK ] VulkanAPITest.layer_norm_2d (288 ms)
[ RUN      ] VulkanAPITest.layer_norm_3d
[       OK ] VulkanAPITest.layer_norm_3d (302 ms)
[ RUN      ] VulkanAPITest.layer_norm_4d
[       OK ] VulkanAPITest.layer_norm_4d (8 ms)
[----------] 4 tests from VulkanAPITest (668 ms total)

[----------] Global test environment tear-down
[==========] 4 tests from 1 test suite ran. (668 ms total)
[  PASSED  ] 4 tests.
```

Reviewed By: yipjustin

Differential Revision: D50436726

Pull Request resolved: pytorch#112322
Approved by: https://github.com/yipjustin
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/trunk Trigger trunk jobs on your pull request fb-exported Merged module: vulkan release notes: vulkan release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants