Change constant torch.tensor to torch.full by MerHS · Pull Request #20061 · huggingface/transformers · GitHub

Conversation

@MerHS
Contributor

@MerHS MerHS commented Nov 4, 2022

What does this PR do?

Change torch.tensor to torch.full in GPT-2 to avoid CPU-GPU synchronization.

Benchmarks with PyTorch Profiler

GPT-2 with torch.tensor

Here's a trace of a single GPT-2 training iteration with 12 GPT-2 blocks, 2 GPUs, and DDP.
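As a rough sketch of how such a trace can be collected (hypothetical script; the actual benchmark uses 2 GPUs with DDP, which is dropped here for brevity):

import torch
from torch.profiler import profile, ProfilerActivity
from transformers import GPT2Config, GPT2LMHeadModel

# Hypothetical profiling setup: run one training step of a small GPT-2 and
# export a Chrome trace, then look for cudaStreamSynchronize entries.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel(GPT2Config(n_layer=12)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
input_ids = torch.randint(0, model.config.vocab_size, (2, 128), device=device)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    loss = model(input_ids=input_ids, labels=input_ids).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

prof.export_chrome_trace("gpt2_step.json")  # inspect in chrome://tracing or Perfetto
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))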

In the _attn function, there are two torch.tensor calls. They trigger CPU-to-GPU memory copies, which in turn call cudaStreamSynchronize.
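For illustration, a minimal sketch of the offending pattern (names such as attn_weights, causal_mask, and mask_value mirror the GPT-2 _attn code, but the snippet is not a verbatim copy of the modeling file):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
attn_weights = torch.randn(1, 12, 128, 128, device=device)
causal_mask = torch.tril(torch.ones(128, 128, device=device)).bool()

# torch.tensor builds the scalar constant on the CPU, and the following
# .to(device) copies it to the GPU, forcing a cudaStreamSynchronize.
mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)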

How to fix

As described in the PyTorch Recipes, we can avoid CPU-GPU synchronization by calling torch.full directly instead of torch.tensor followed by .to(). Since the two torch.tensor calls only create constant scalar tensors, we can change them to torch.full([], ...), and the code behaves the same way.
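A minimal sketch of the replacement (illustrative only, not the exact diff):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
attn_weights = torch.randn(1, 12, 128, 128, device=device)

# Before: scalar built on the CPU, then copied to the GPU.
old = torch.tensor(torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype).to(attn_weights.device)
# After: torch.full([], ...) creates the 0-dim constant directly on the target
# device, so no host-to-device copy (and no synchronization) is needed.
new = torch.full([], torch.finfo(attn_weights.dtype).min,
                 dtype=attn_weights.dtype, device=attn_weights.device)

assert torch.equal(old, new)  # identical value, dtype, shape, and device
# The attention-scaling constant can be handled the same way, e.g.
# torch.full([], value.size(-1) ** 0.5, dtype=..., device=...).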

GPT-2 without torch.tensor

After the patch, every cudaStreamSynchronize call is gone, and the duration of a single iteration is reduced by 0.5%.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@LysandreJik

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Nov 4, 2022

The documentation is not available anymore as the PR was closed or merged.

Collaborator

@sgugger sgugger left a comment


LGTM, thanks for your PR!
@michaelbenayoun would this cause any issue for our torch FX/ONNX conversions?

@michaelbenayoun
Member

For FX, I think this is already tested in the CI so I guess it does not break things.
For the ONNX export, it's not tested but it should not break things IMO.

@sgugger sgugger merged commit 707b12a into huggingface:main Nov 4, 2022
@MerHS MerHS deleted the remove-torch-tensor branch November 4, 2022 14:44
@JingyaHuang
Contributor

Following the change, the training with ONNX Runtime breaks as mask_value and attn_weights don't have the same dtype after being traced. Will open a PR to fix this issue.

======================================================================
ERROR: test_ort_trainer (__main__.TestORTTrainer) (model_name='gpt2', dataset_name='sst2', inference_with_ort=False)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_onnxruntime_train.py", line 131, in test_ort_trainer
    train_result = trainer.train()
  File "/workspace/optimum/onnxruntime/trainer.py", line 349, in train
    return inner_training_loop(
  File "/workspace/optimum/onnxruntime/trainer.py", line 615, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2523, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2555, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_utils.py", line 371, in _forward
    return ortmodule._torch_module.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_utils.py", line 351, in _forward
    return torch_module_ort._execution_manager(torch_module_ort.is_training()).forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 273, in forward
    self._fallback_manager.handle_exception(
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_fallback.py", line 162, in handle_exception
    raise exception
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 210, in forward
    self._initialize_graph_builder()
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 478, in _initialize_graph_builder
    self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:731 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(float16) in node (Where_223).
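For context, a minimal, hypothetical illustration of the mismatch (the module, shapes, and constant below are made up for this sketch and are not the follow-up fix):

import torch

# If the traced graph bakes mask_value as a float32 constant while the
# attention weights run in float16, the exported Where node sees two element
# types, which ONNX disallows even though eager PyTorch type-promotes it away.
class Attn(torch.nn.Module):
    def forward(self, attn_weights, causal_mask):
        mask_value = torch.full([], torch.finfo(torch.float32).min)  # float32 constant
        return torch.where(causal_mask, attn_weights, mask_value)

attn_weights = torch.randn(1, 4, 4, dtype=torch.float16)  # float16 at run time
causal_mask = torch.tril(torch.ones(4, 4)).bool()
out = Attn()(attn_weights, causal_mask)  # runs in eager via type promotion; an exported Where op cannot mix types
# Keeping both branches in one dtype, e.g. mask_value.to(attn_weights.dtype) or
# building the constant with dtype=attn_weights.dtype inside the traced graph,
# avoids the Where type clash.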
