Fix NJT linear_backward() memory usage by jbschlosser · Pull Request #141163 · pytorch/pytorch

Conversation

@jbschlosser (Contributor) commented on Nov 20, 2024

Stack from ghstack (oldest at bottom):

Fixes #141112

The formula we're using for `linear_backward()` is inefficient for higher-dim input sizes, even when the input is only trivially higher-dim (e.g. via `unsqueeze()`). This PR updates the formula to match the more efficient version employed by NST (the strided nested tensor implementation). Specifically, note the leading-dim collapse of `grad_output`'s values before the various matmuls are computed:
https://github.com/pytorch/pytorch/blob/d5ee1d1b581da8399d604bd661ea5fe454b485d6/aten/src/ATen/native/nested/NestedTensorBackward.cpp#L37-L70

std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    std::array<bool, 3> output_mask) {
  if (!grad_output.defined()) {
    return std::tuple<Tensor, Tensor, Tensor>{Tensor(), Tensor(), Tensor()};
  }
  Tensor grad_input, grad_weight, grad_bias;
  auto grad_output_contiguous = grad_output.contiguous();
  auto* nt_grad_output = get_nested_tensor_impl(grad_output_contiguous);
  auto* nt_input = get_nested_tensor_impl(input);
  TORCH_INTERNAL_ASSERT(nt_grad_output != nullptr);
  TORCH_INTERNAL_ASSERT(nt_input != nullptr);
  TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(nt_grad_output));
  auto grad_output_buffer = nt_grad_output->get_buffer();
  auto input_buffer = nt_input->get_buffer();
  // Collapse all leading dims of grad_output's values into one so the matmuls
  // below operate on 2D tensors of shape (total_rows, out_features).
  auto reshaped_grad = grad_output_buffer.reshape({-1, weight.size(0)});
  if (output_mask[0]) {
    // grad_input = grad_output @ W, re-wrapped with the input's nested sizes.
    auto grad_input_buffer = at::mm(reshaped_grad, weight).view({-1});
    auto grad_input_nt_size = nt_input->get_nested_sizes().clone();
    grad_input = wrap_buffer(grad_input_buffer, grad_input_nt_size);
  }
  if (output_mask[1]) {
    // grad_weight = grad_output^T @ input, computed on the flattened buffers.
    grad_weight =
        at::mm(reshaped_grad.t(), input_buffer.reshape({-1, weight.size(1)}));
  }
  if (output_mask[2]) {
    // grad_bias = sum of grad_output over all collapsed leading dims.
    grad_bias = reshaped_grad.sum(0);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}
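For intuition, here is a small dense-tensor check (not the actual kernel; the sizes are made up) showing that collapsing all leading dims before a single 2D matmul gives the same grad_weight as contracting over each leading dim explicitly:

import torch

# Made-up leading dims (B, T, U) and feature sizes; purely illustrative.
B, T, U, in_f, out_f = 2, 3, 5, 8, 4
inp = torch.randn(B, T, U, in_f)
grad_out = torch.randn(B, T, U, out_f)

# Naive formula: contract over every leading dim explicitly.
grad_w_naive = torch.einsum("btuo,btui->oi", grad_out, inp)

# Collapsed formula (what the code above does): flatten all leading dims,
# then perform a single 2D matmul.
grad_w_collapsed = grad_out.reshape(-1, out_f).t() @ inp.reshape(-1, in_f)

torch.testing.assert_close(grad_w_naive, grad_w_collapsed)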

Testing for correctness is done via the existing gradcheck tests (e.g. `test_backward_nn_functional_linear`). I added a memory usage test as well, though there is likely a better way to do this.
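For context, a minimal sketch of the kind of input the linked issue is about: an NJT made trivially higher-dim via `unsqueeze()` before the linear. The shapes below are assumptions (not taken from #141112), and op support may vary by PyTorch version; treat this as a sketch rather than a verified repro.

import torch
import torch.nn.functional as F

# Jagged NJT with logical shape (B, j1, D): a batch of variable-length sequences.
nt = torch.nested.nested_tensor(
    [torch.randn(s, 64) for s in (3, 5, 7)],
    layout=torch.jagged,
    requires_grad=True,
)
weight = torch.randn(64, 64, requires_grad=True)

# Trivially higher-dim input: (B, j1, D) -> (B, j1, 1, D).
out = F.linear(nt.unsqueeze(2), weight)

# The backward through linear is where the old NJT formula was inefficient;
# the fix collapses the leading dims of grad_output's values before the matmuls.
out.values().sum().backward()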

@pytorch-bot (bot) commented on Nov 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141163

Note: Links to docs will display an error until the docs builds have been completed.


❌ 1 New Failure

As of commit ddfa80b with merge base a440a01:

NEW FAILURE - The following job has failed:

  • linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test (gh)
    RuntimeError: cuDNN version incompatibility: PyTorch was compiled against (9, 5, 1) but found runtime version (9, 1, 0). PyTorch already comes bundled with cuDNN. One option to resolving this error is to ensure PyTorch can find the bundled cuDNN. one possibility is that there is a conflicting cuDNN in LD_LIBRARY_PATH.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jbschlosser added the "topic: bug fixes" (topic category) and "release notes: nested tensor" (Changes that have a direct impact on nested tensors) labels on Nov 20, 2024
@jbschlosser requested a review from cpuhrsch on November 20, 2024 at 20:42
@jbschlosser (Contributor, Author) commented on Nov 20, 2024

Discussed offline: reset the max memory stat via `torch.cuda.reset_max_memory_allocated()` and measure the max afterwards. If it's too high (in practice, I see over 3 GB allocated during the backward call), fail the test. Assuming this stat is process-isolated, this should work fine (we don't run CI tests multi-threaded, only multi-process). If the test fails later on, we can revisit this, but at least the fix is in :)
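Roughly, that approach might look like the sketch below (the helper name, sizes, and threshold are assumptions for illustration; the actual test added in the PR may differ):

import torch
import torch.nn.functional as F

def check_njt_linear_backward_peak_memory(max_allowed_bytes=2 * 1024**3):
    # Build a reasonably large jagged NJT on CUDA so a regression is visible.
    nt = torch.nested.nested_tensor(
        [torch.randn(s, 128) for s in (1024, 2048, 4096)],
        layout=torch.jagged,
        device="cuda",
        requires_grad=True,
    )
    weight = torch.randn(128, 128, device="cuda", requires_grad=True)
    out = F.linear(nt.unsqueeze(2), weight)

    # Reset the peak-allocation stat so only the backward call is measured.
    torch.cuda.reset_max_memory_allocated()
    out.values().sum().backward()

    # With the old formula, over 3 GB was reportedly allocated here; fail if
    # the peak gets anywhere near that again.
    peak = torch.cuda.max_memory_allocated()
    assert peak < max_allowed_bytes, f"peak memory too high: {peak} bytes"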

@jbschlosser (Contributor, Author) commented:

@pytorchbot merge

@pytorch-bot added the "ciflow/trunk" (Trigger trunk jobs on your pull request) label on Nov 20, 2024
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status on the linked workflow run.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: 1 mandatory check(s) failed.

Dig deeper by viewing the failures on HUD.

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

jbschlosser added a commit that referenced this pull request Nov 20, 2024
ghstack-source-id: 3785f3a
Pull Request resolved: #141163
@jbschlosser (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status on the linked workflow run.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: 1 job has failed: linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test

Details for Dev Infra team: raised by workflow job.

@jbschlosser (Contributor, Author) commented:

@pytorchbot merge -i

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged while ignoring the following 1 check: linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status on the linked workflow run.

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#141163
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch, https://github.com/soulitzer
@github-actions bot deleted the gh/jbschlosser/202/head branch on December 22, 2024 at 02:10

Labels: ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · release notes: nested tensor (Changes that have a direct impact on nested tensors) · topic: bug fixes (topic category)
