KEMBAR78
[Pytorch][Vulkan] var.dim by copyrightly · Pull Request #111965 · pytorch/pytorch · GitHub
Skip to content

Conversation

@copyrightly
Copy link
Contributor

Summary:
We implement torch.var for tensors of 2d to 4d.

By using the mean, sub and pow ops, we can compute the variance as below without adding a new shader.

at::Tensor self_mean = self.mean(opt_dim, true);
at::Tensor output = (self.sub(self_mean).pow(2)).mean(opt_dim, keepdim);

Test Plan:

[luwei@devbig984.prn1 /data/users/luwei/fbsource (2da0640c6)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*var*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *var*
[==========] Running 6 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 6 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.var_2d_unbiased
[       OK ] VulkanAPITest.var_2d_unbiased (322 ms)
[ RUN      ] VulkanAPITest.var_2d_biased
[       OK ] VulkanAPITest.var_2d_biased (0 ms)
[ RUN      ] VulkanAPITest.var_3d_unbiased
[       OK ] VulkanAPITest.var_3d_unbiased (2 ms)
[ RUN      ] VulkanAPITest.var_3d_biased
[       OK ] VulkanAPITest.var_3d_biased (2 ms)
[ RUN      ] VulkanAPITest.var_4d_unbiased
[       OK ] VulkanAPITest.var_4d_unbiased (175 ms)
[ RUN      ] VulkanAPITest.var_4d_biased
[       OK ] VulkanAPITest.var_4d_biased (5 ms)
[----------] 6 tests from VulkanAPITest (508 ms total)

[----------] Global test environment tear-down
[==========] 6 tests from 1 test suite ran. (508 ms total)
[  PASSED  ] 6 tests.

Reviewed By: yipjustin

Differential Revision: D50398925

@pytorch-bot pytorch-bot bot added the ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR label Oct 24, 2023
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 24, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111965

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit c814495 with merge base 94e90c1 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D50398925

copyrightly added a commit that referenced this pull request Oct 26, 2023
Summary:

We implement [`torch.var`](https://pytorch.org/docs/stable/generated/torch.var.html) for tensors of 2d to 4d.

By using the `mean`, `sub` and `pow` ops, we can compute the variance as below without adding a new shader.
```
at::Tensor self_mean = self.mean(opt_dim, true);
at::Tensor output = (self.sub(self_mean).pow(2)).mean(opt_dim, keepdim);
```

Test Plan:
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (2da0640c6)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*var*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *var*
[==========] Running 6 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 6 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.var_2d_unbiased
[       OK ] VulkanAPITest.var_2d_unbiased (322 ms)
[ RUN      ] VulkanAPITest.var_2d_biased
[       OK ] VulkanAPITest.var_2d_biased (0 ms)
[ RUN      ] VulkanAPITest.var_3d_unbiased
[       OK ] VulkanAPITest.var_3d_unbiased (2 ms)
[ RUN      ] VulkanAPITest.var_3d_biased
[       OK ] VulkanAPITest.var_3d_biased (2 ms)
[ RUN      ] VulkanAPITest.var_4d_unbiased
[       OK ] VulkanAPITest.var_4d_unbiased (175 ms)
[ RUN      ] VulkanAPITest.var_4d_biased
[       OK ] VulkanAPITest.var_4d_biased (5 ms)
[----------] 6 tests from VulkanAPITest (508 ms total)

[----------] Global test environment tear-down
[==========] 6 tests from 1 test suite ran. (508 ms total)
[  PASSED  ] 6 tests.
```

Reviewed By: yipjustin

Differential Revision: D50398925
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D50398925

Summary:

We implement [`torch.var`](https://pytorch.org/docs/stable/generated/torch.var.html) for tensors of 2d to 4d.

By using the `mean`, `sub` and `pow` ops, we can compute the variance as below without adding a new shader.
```
at::Tensor self_mean = self.mean(opt_dim, true);
at::Tensor output = (self.sub(self_mean).pow(2)).mean(opt_dim, keepdim);
```

Test Plan:
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (2da0640c6)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*var*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *var*
[==========] Running 6 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 6 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.var_2d_unbiased
[       OK ] VulkanAPITest.var_2d_unbiased (322 ms)
[ RUN      ] VulkanAPITest.var_2d_biased
[       OK ] VulkanAPITest.var_2d_biased (0 ms)
[ RUN      ] VulkanAPITest.var_3d_unbiased
[       OK ] VulkanAPITest.var_3d_unbiased (2 ms)
[ RUN      ] VulkanAPITest.var_3d_biased
[       OK ] VulkanAPITest.var_3d_biased (2 ms)
[ RUN      ] VulkanAPITest.var_4d_unbiased
[       OK ] VulkanAPITest.var_4d_unbiased (175 ms)
[ RUN      ] VulkanAPITest.var_4d_biased
[       OK ] VulkanAPITest.var_4d_biased (5 ms)
[----------] 6 tests from VulkanAPITest (508 ms total)

[----------] Global test environment tear-down
[==========] 6 tests from 1 test suite ran. (508 ms total)
[  PASSED  ] 6 tests.
```

Reviewed By: yipjustin

Differential Revision: D50398925
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D50398925

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D50398925

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 27, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the export-D50398925 branch October 30, 2023 14:23
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Summary:
We implement [`torch.var`](https://pytorch.org/docs/stable/generated/torch.var.html) for tensors of 2d to 4d.

By using the `mean`, `sub` and `pow` ops, we can compute the variance as below without adding a new shader.
```
at::Tensor self_mean = self.mean(opt_dim, true);
at::Tensor output = (self.sub(self_mean).pow(2)).mean(opt_dim, keepdim);
```

Test Plan:
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (2da0640c6)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*var*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *var*
[==========] Running 6 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 6 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.var_2d_unbiased
[       OK ] VulkanAPITest.var_2d_unbiased (322 ms)
[ RUN      ] VulkanAPITest.var_2d_biased
[       OK ] VulkanAPITest.var_2d_biased (0 ms)
[ RUN      ] VulkanAPITest.var_3d_unbiased
[       OK ] VulkanAPITest.var_3d_unbiased (2 ms)
[ RUN      ] VulkanAPITest.var_3d_biased
[       OK ] VulkanAPITest.var_3d_biased (2 ms)
[ RUN      ] VulkanAPITest.var_4d_unbiased
[       OK ] VulkanAPITest.var_4d_unbiased (175 ms)
[ RUN      ] VulkanAPITest.var_4d_biased
[       OK ] VulkanAPITest.var_4d_biased (5 ms)
[----------] 6 tests from VulkanAPITest (508 ms total)

[----------] Global test environment tear-down
[==========] 6 tests from 1 test suite ran. (508 ms total)
[  PASSED  ] 6 tests.
```

Reviewed By: yipjustin

Differential Revision: D50398925

Pull Request resolved: pytorch#111965
Approved by: https://github.com/yipjustin
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Summary:
We implement [`torch.var`](https://pytorch.org/docs/stable/generated/torch.var.html) for tensors of 2d to 4d.

By using the `mean`, `sub` and `pow` ops, we can compute the variance as below without adding a new shader.
```
at::Tensor self_mean = self.mean(opt_dim, true);
at::Tensor output = (self.sub(self_mean).pow(2)).mean(opt_dim, keepdim);
```

Test Plan:
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (2da0640c6)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*var*"
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *var*
[==========] Running 6 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 6 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.var_2d_unbiased
[       OK ] VulkanAPITest.var_2d_unbiased (322 ms)
[ RUN      ] VulkanAPITest.var_2d_biased
[       OK ] VulkanAPITest.var_2d_biased (0 ms)
[ RUN      ] VulkanAPITest.var_3d_unbiased
[       OK ] VulkanAPITest.var_3d_unbiased (2 ms)
[ RUN      ] VulkanAPITest.var_3d_biased
[       OK ] VulkanAPITest.var_3d_biased (2 ms)
[ RUN      ] VulkanAPITest.var_4d_unbiased
[       OK ] VulkanAPITest.var_4d_unbiased (175 ms)
[ RUN      ] VulkanAPITest.var_4d_biased
[       OK ] VulkanAPITest.var_4d_biased (5 ms)
[----------] 6 tests from VulkanAPITest (508 ms total)

[----------] Global test environment tear-down
[==========] 6 tests from 1 test suite ran. (508 ms total)
[  PASSED  ] 6 tests.
```

Reviewed By: yipjustin

Differential Revision: D50398925

Pull Request resolved: pytorch#111965
Approved by: https://github.com/yipjustin
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/trunk Trigger trunk jobs on your pull request fb-exported Merged module: vulkan release notes: vulkan release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants