Enable FP8 full finetune distributed #2546
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2546. Note: links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures as of commit dcdeea4 with merge base 7d92c10:
- NEW FAILURE - the following job has failed on this PR.
- BROKEN TRUNK - the following jobs failed but were already present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is looking good! I'm excited to see some initial runs
recipes/full_finetune_distributed.py (outdated)
```
if self._enable_fp8_training:
    raise ValueError(
        "Float8 training + tensor parallelism is not supported yet"
    )
```
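For context, here is a minimal sketch (the attribute and function names are assumptions, not torchtune's actual ones) of how such a gate might be expressed so that it only rejects FP8 when tensor parallelism is actually in use, leaving FP8 + plain FSDP allowed:
```
def validate_fp8_tp_compat(enable_fp8_training: bool, tensor_parallel_dim: int) -> None:
    """Hypothetical standalone version of the gate above.

    Only rejects the combination when tensor parallelism is actually in use,
    so FP8 training with plain FSDP remains allowed.
    """
    if tensor_parallel_dim > 1 and enable_fp8_training:
        raise ValueError(
            "Float8 training + tensor parallelism is not supported yet"
        )
```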
Regarding this, are there still known gaps that need to be addressed? E.g. of the three issues from @vkuzo's comment it seems like one is closed, one was a red herring, and one is not relevant to us today (as it's related to async TP). Wondering if we're good to just gate on nightly PyTorch instead of disabling entirely?
the main issue I was worried about earlier turned out to be due to bad hardware, so we are all good!
are there plans / needs to support asyncTP in torchtune?
@vkuzo asyncTP is not something that has really been on our radar. But given that it shows some nice speedups, I think it would be worth exploring.
Codecov Report

Additional details and impacted files:
```
@@            Coverage Diff             @@
##             main    #2546      +/-   ##
==========================================
- Coverage   66.97%   65.47%    -1.51%
==========================================
  Files         375      376        +1
  Lines       22295    22438      +143
==========================================
- Hits        14932    14691      -241
- Misses       7363     7747      +384
```
☔ View full report in Codecov by Sentry.
this is great! Are there any performance benchmarks we can put in the PR summary?
```
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
self._checkpoint_client = CheckpointClient(cfg)
self._enable_fp8_training = cfg.get("enable_fp8_training", False)
self._fp8_recipe_name = cfg.get("fp8_recipe_name", None)
```
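As a reference point, here is a minimal sketch, not the recipe's actual helper, of how these two config fields could map onto torchao's float8 conversion API (the function name and the all-gather default are assumptions based on the discussion in this thread):
```
from typing import Optional

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def maybe_convert_to_float8(
    model: nn.Module,
    enable_fp8_training: bool,
    fp8_recipe_name: Optional[str] = None,
) -> nn.Module:
    """Sketch: swap nn.Linear modules for Float8Linear based on the two config fields."""
    if not enable_fp8_training:
        return model
    if fp8_recipe_name is not None:
        # e.g. "tensorwise", "rowwise", "rowwise_with_gw_hp"
        config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
    else:
        # Assumed default path: tensorwise scaling with float8 all-gather under FSDP2
        config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
    convert_to_float8_training(model, config=config)
    return model
```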
@vkuzo does this level of configuration look good to you? Any other fields you think I should expose?
I remember you mentioned a performance regression in previous experiments, and we then root caused it to the fact that Float8Linear was applied to parts of the model that torch.compile was not applied to. I'm not familiar with how torch.compile is applied in torchtune, but it would be good to ensure Float8Linear is not applied to any regions which are not using compile.
@vkuzo @andrewor14 From the torchao source it appears the rowwise scaling options do not support force_recompute_fp8_weight_in_bwd by default.
Given we are defaulting enable_fsdp_float8_all_gather=True in the case where a recipe name isn't provided, should we also set force_recompute_fp8_weight_in_bwd, or should this be an option here in the config?
I noticed a TODO in torchao to set this to True by default in the future; I wonder if we should do that in tune now.
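For concreteness, a sketch only: if one did want to pin this flag explicitly, it is a field on torchao's `Float8LinearConfig` alongside `enable_fsdp_float8_all_gather` (per the reply below, it should not be needed on recent nightlies):
```
from torchao.float8 import Float8LinearConfig

# Sketch: explicit flags for the default (tensorwise) path; not necessarily
# what the recipe ends up doing.
fp8_config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
)
```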
I think @vkuzo mentioned that force_recompute_fp8_weight_in_bwd is no longer needed after a certain commit. This PR gates on the nightlies so we don't need this flag anymore
> I'm not familiar with how torch.compile is applied in torchtune, but it would be good to ensure Float8Linear is not applied to any regions which are not using compile.
Regarding @vkuzo's comment, torchtune does per-layer compile with separate compilation of the loss. You can see the utility we use here. This means that there may be two cases that are not covered:
1. output projection (i.e. LM head)
2. any linear layers in e.g. a vision encoder for a multimodal model
Given the numbers in the test plan, it seems (1) is not a hard blocker. We could try it out with e.g. Llama 3.2 Vision just to make sure there are no regressions there. Otherwise I don't have any huge concerns here.
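For illustration, torchao's conversion helper accepts a `module_filter_fn`, so one hedged sketch of keeping Float8Linear off non-compiled regions could look like the following (the FQN patterns `output` and `encoder` are assumptions, not torchtune's actual module names):
```
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def _skip_uncompiled_linears(module: nn.Module, fqn: str) -> bool:
    # Return True to swap this nn.Linear for Float8Linear. Skip the LM head
    # ("output") and any encoder linears that are not compiled per-layer.
    return not (fqn == "output" or fqn.startswith("encoder"))


def convert_compiled_regions_only(model: nn.Module) -> None:
    convert_to_float8_training(
        model,
        config=Float8LinearConfig(enable_fsdp_float8_all_gather=True),
        module_filter_fn=_skip_uncompiled_linears,
    )
```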
Sounds good. Yes we already do (1) in this PR, will try Llama 3.2 Vision
Thanks. Yeah these are still in progress, will add shortly.
I pulled this down to do some testing, but had trouble with tensor parallelism even after tweaking the plan.
This script appears to set
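For context, a hedged sketch of the kind of float8-aware TP plan tweak being attempted here: torchao ships float8 variants of the colwise/rowwise parallel styles, and the module names below are illustrative rather than torchtune's actual FQNs:
```
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


def apply_float8_tp(decoder_layer: nn.Module, tp_mesh: DeviceMesh) -> None:
    # Swap the plain Colwise/Rowwise styles for their float8-aware versions.
    layer_plan = {
        "attn.q_proj": Float8ColwiseParallel(),
        "attn.k_proj": Float8ColwiseParallel(),
        "attn.v_proj": Float8ColwiseParallel(),
        "attn.output_proj": Float8RowwiseParallel(),
        "mlp.w1": Float8ColwiseParallel(),
        "mlp.w3": Float8ColwiseParallel(),
        "mlp.w2": Float8RowwiseParallel(),
    }
    parallelize_module(decoder_layer, tp_mesh, layer_plan)
```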
@nathan-az I'm running into this error now, even with the original base TP plan (no float8 parallel styles). Do you have more context on this / how you worked around this in the past? Full stack trace: https://gist.github.com/andrewor14/b4058adc32fcb1dafd70574e0eae335d
I can implement that op. I'm just curious how we did it in torchtitan since we clearly have TP support there. Do you see any difference in the torchtune vs torchtitan setup that could be causing this?
I expect there's something, but it's not obvious to me. Unfortunately this is about where I hit a wall. I was able to change my error message by adding the chunk op, but didn't get further. One final note I can add is that I got a different error after disabling it.
Just a thought: is it worth landing this with standard FP8 support for FSDP, raising an error that FP8 + TP isn't supported yet, then adding TP support in a follow-up?
Yeah I think that's a good idea. @ebsmothers does this sound good to you?
@andrewor14 yeah this is fine with me. Let's file an issue to get things working with TP though
Last note from me until we move to a separate issue for TP + FP8. After implementing the missing op, compile created some metadata errors mentioning other ops that still need implementing. I created a gist with a slightly more minimal and self-contained test script; it has very little dependency on torchtune.
@nathan-az Great, thank you for all your investigation! I think we do want TP to work with compile so we probably need to investigate further. Will run some additional benchmarks for Llama 3.2 and let's aim to land this PR without TP.
**Summary:** This commit adds FP8 finetuning to the `full_finetune_distributed` recipe as an optional feature. For Llama3-8B, we saw up to 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature is currently gated on PyTorch nightlies since it depends on recent features added there. However, it will be available in the next torchtune release.

To use this feature, add the following to your config.yaml:
```
enable_fp8_training: true
fp8_recipe_name: tensorwise  # or rowwise, or rowwise_with_gw_hp
```

The default setting uses tensorwise scaling + `enable_fsdp_float8_all_gather=True`, which led to the largest speedups in our experiments.

Based on meta-pytorch#2404 by @nathan-az

**Experimentation:** All experiments were run on 4x H100 GPUs with 94GB memory each. We finetune the model on the cleaned alpaca dataset for 1 epoch, using a batch size of 16 with torch.compile. We use the following commits from all 3 repos:
```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao: 5a78b70
torch: 1017927
```

For Llama3-8B, fp8 finetuning was 14.7% faster with no change in memory usage or quantized accuracy compared to the bf16 baseline:
```
experiment_name         tok/s                peak_mem_active   peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  ----------------  ----------------  -------------------
full                    2773.473 (+0.000%)   18.481 (+0.000%)  18.481 (+0.000%)  34.291 (+0.000%)
fp8_noname              3182.220 (+14.738%)  18.484 (+0.014%)  18.484 (+0.014%)  34.325 (+0.097%)
fp8_tensorwise          3159.676 (+13.925%)  18.484 (+0.014%)  18.484 (+0.014%)  34.325 (+0.097%)
fp8_rowwise             2790.424 (+0.611%)   18.496 (+0.078%)  18.496 (+0.078%)  34.327 (+0.103%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.584 (+0.000)   9.419 (+0.000)
fp8_noname              0.585 (+0.000)   9.431 (+0.012)
fp8_tensorwise          0.584 (+0.000)   9.421 (+0.002)
fp8_rowwise             0.583 (-0.002)   9.421 (+0.002)
```

A few more observations here:
- The best tok/s improvement came from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline

For Llama3.1-8B, we made similar observations, with up to 14.3% faster finetuning and no change in quantized accuracy. However, memory usage did increase minimally (+2%) for most fp8 settings:
```
experiment_name         tok/s                peak_mem_active   peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  ----------------  ----------------  -------------------
full                    2768.292 (+0.000%)   18.541 (+0.000%)  18.541 (+0.000%)  34.270 (+0.000%)
fp8_noname              3164.370 (+14.308%)  18.542 (+0.008%)  18.542 (+0.008%)  34.963 (+2.021%)
fp8_tensorwise          3136.952 (+13.317%)  18.542 (+0.008%)  18.542 (+0.008%)  34.963 (+2.021%)
fp8_rowwise             2790.672 (+0.808%)   18.554 (+0.073%)  18.554 (+0.073%)  34.389 (+0.348%)
fp8_rowwise_with_gw_hp  3144.678 (+13.596%)  18.551 (+0.056%)  18.551 (+0.056%)  34.966 (+2.032%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.594 (+0.000)   9.087 (+0.000)
fp8_noname              0.593 (-0.001)   9.070 (-0.017)
fp8_tensorwise          0.593 (-0.001)   9.061 (-0.026)
fp8_rowwise             0.593 (-0.000)   9.086 (-0.001)
fp8_rowwise_with_gw_hp  0.595 (+0.001)   9.087 (+0.000)
```

Llama3.2-3B saw up to 16.5% faster finetuning for rowwise with high precision `grad_weight`, which is a bigger improvement than tensorwise alone.



Similarly, there are no degradations in memory usage or quantized accuracy:
```
experiment_name         tok/s                peak_mem_active   peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  ----------------  ----------------  -------------------
full                    6502.143 (+0.000%)   15.917 (+0.000%)  15.917 (+0.000%)  30.090 (+0.000%)
fp8_noname              7205.386 (+10.816%)  15.917 (+0.003%)  15.917 (+0.003%)  30.010 (-0.266%)
fp8_tensorwise          7222.198 (+11.074%)  15.917 (+0.003%)  15.917 (+0.003%)  30.010 (-0.266%)
fp8_rowwise             6387.968 (-1.756%)   15.916 (-0.002%)  15.916 (-0.002%)  29.158 (-3.096%)
fp8_rowwise_with_gw_hp  7573.698 (+16.480%)  15.917 (+0.001%)  15.917 (+0.001%)  29.516 (-1.908%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.533 (+0.000)   12.407 (+0.000)
fp8_noname              0.533 (+0.000)   12.414 (+0.007)
fp8_tensorwise          0.533 (+0.000)   12.412 (+0.005)
fp8_rowwise             0.533 (-0.000)   12.420 (+0.013)
fp8_rowwise_with_gw_hp  0.534 (+0.001)   12.416 (+0.009)
```

**Test Plan:**

Experiment command:
```
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
  enable_fp8_training=true \
  fp8_recipe_name=tensorwise \
  epochs=1 \
  batch_size=16 \
  compile=true \
  dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
  checkpointer.output_dir="$LOG_DIR" \
  output_dir="${LOG_DIR}/metrics" \
  metric_logger.log_dir="${LOG_DIR}/metrics"
```
(full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)

Unit tests:
```
pytest tests -k test_convert_to_float8_training
pytest tests -k test_is_fp8_tensorwise_scaling
```
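As a quick sanity check (not part of the PR, just an assumed helper), one could confirm the swap happened by counting `Float8Linear` modules after conversion:
```
import torch.nn as nn
from torchao.float8.float8_linear import Float8Linear


def count_float8_linears(model: nn.Module) -> int:
    # Count how many linears were swapped to Float8Linear after conversion.
    return sum(isinstance(m, Float8Linear) for m in model.modules())
```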
Thanks for adding this! This looks good to me. Gonna re-run CI (I think there is some unrelated sporadic failure), after that let's land this!
**Summary:** Similar to meta-pytorch#1854. Update the `qat_distributed` recipe to mirror `full_finetune_distributed` up until a6db644. The new major feature that is excluded from `qat_distributed` is FP8 finetuning (meta-pytorch#2546), since QAT FP8 is not supported in torchao yet.

Diff between full finetune and QAT recipes: P1809370361
```
diff --color recipes/full_finetune_distributed.py recipes/qat_distributed.py
```

**Test Plan:**

Finetune:
```
tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full \
  epochs=1 \
  batch_size=16 \
  dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
  checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat \
  output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/metrics \
  metric_logger.log_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/metrics \
  quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
  quantizer.groupsize=32
```

Quantize:
```
tune run quantize --config quantization \
  model._component_=torchtune.models.llama3_2.llama3_2_3b \
  checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
  checkpointer.checkpoint_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0 \
  checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0_out \
  'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \
  checkpointer.model_type=LLAMA3 \
  quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
  quantizer.groupsize=32
```

Eval:
```
tune run eleuther_eval --config eleuther_evaluation \
  batch_size=1 \
  'tasks=[wikitext]' \
  model._component_=torchtune.models.llama3_2.llama3_2_3b \
  checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
  checkpointer.checkpoint_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0 \
  checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0_out \
  'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \
  checkpointer.model_type=LLAMA3 \
  tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
  tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
  quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
  quantizer.groupsize=32
```

Results:
```
experiment_name          tok/s                peak_mem_active   peak_mem_alloc    peak_mem_reserved
-----------------------  -------------------  ----------------  ----------------  -------------------
Llama3.2-3B_alpaca_full  4677.163 (+0.000%)   12.261 (+0.000%)  12.261 (+0.000%)  15.778 (+0.000%)
Llama3.2-3B_alpaca_qat   1873.316 (-59.948%)  13.047 (+6.409%)  13.047 (+6.409%)  17.226 (+9.176%)

experiment_name          hellaswag_acc                   wikitext_word_perplexity
-----------------------  ------------------------------  -------------------------------
Llama3.2-3B_alpaca_full  0.470 quant, 0.534 float        18.563 quant, 12.364 float
Llama3.2-3B_alpaca_qat   0.511 quant, recovered 63.043%  13.792 quant, recovered 76.962%
```