Enable FP8 full finetune distributed by andrewor14 · Pull Request #2546 · meta-pytorch/torchtune

Conversation

@andrewor14
Contributor

@andrewor14 andrewor14 commented Apr 1, 2025

Summary: This commit adds FP8 finetuning to the full_finetune_distributed recipe as an optional feature. For Llama3-8B, we saw up to 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature is currently gated on PyTorch nightlies since it depends on recent features added there. However, it will be available in the next torchtune release.

To use this feature, add the following to your config.yaml:

enable_fp8_training: true
fp8_recipe_name: tensorwise  # or rowwise, or rowwise_with_gw_hp

The default setting uses tensorwise scaling + enable_fsdp_float8_all_gather=True, which led to the largest speedups in our experiments.
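
Under the hood this maps onto torchao's float8 conversion API. Below is a minimal sketch, assuming torchao's `convert_to_float8_training` / `Float8LinearConfig`; the actual helper in `torchtune/training/quantization.py` may be structured and named differently:

```python
# Minimal sketch (not the exact recipe code): swap eligible nn.Linear layers
# for torchao's Float8Linear according to the config options above.
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def apply_fp8_training(model, fp8_recipe_name=None):
    if fp8_recipe_name is not None:
        # "tensorwise", "rowwise", or "rowwise_with_gw_hp"
        config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
    else:
        # Default: tensorwise scaling with float8 all-gather under FSDP2
        config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
    # Replaces nn.Linear modules with Float8Linear in place and returns the model
    return convert_to_float8_training(model, config=config)
```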

Based on #2404 by @nathan-az

Experimentation: All experiments were run on 4x H100 GPUs with 94GB memory each. We finetune the model on the cleaned alpaca dataset for 1 epoch, using a batch size of 16 with torch.compile. We use the following commits from all 3 repos:

torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao: 5a78b70
torch: 1017927

For Llama3-8B, fp8 finetuning was up to 14.7% faster than the bf16 baseline, with no change in memory usage or quantized accuracy:

experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    2773.473 (+0.000%)   18.481 (+0.000%)   18.481 (+0.000%)  34.291 (+0.000%)
fp8_noname              3182.220 (+14.738%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_tensorwise          3159.676 (+13.925%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_rowwise             2790.424 (+0.611%)   18.496 (+0.078%)   18.496 (+0.078%)  34.327 (+0.103%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.584 (+0.000)   9.419 (+0.000)
fp8_noname              0.585 (+0.000)   9.431 (+0.012)
fp8_tensorwise          0.584 (+0.000)   9.421 (+0.002)
fp8_rowwise             0.583 (-0.002)   9.421 (+0.002)

A few more observations here:

  • The best tok/s improvement was from the default setting (fp8_noname)
  • fp8_rowwise was the worst fp8 configuration, though still marginally better than the baseline

For Llama3.1-8B, we observed similar results, with up to 14.3% faster finetuning and no change in quantized accuracy. However, memory usage did increase slightly (+2%) for most fp8 settings:

experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    2768.292 (+0.000%)   18.541 (+0.000%)   18.541 (+0.000%)  34.270 (+0.000%)
fp8_noname              3164.370 (+14.308%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_tensorwise          3136.952 (+13.317%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_rowwise             2790.672 (+0.808%)   18.554 (+0.073%)   18.554 (+0.073%)  34.389 (+0.348%)
fp8_rowwise_with_gw_hp  3144.678 (+13.596%)  18.551 (+0.056%)   18.551 (+0.056%)  34.966 (+2.032%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.594 (+0.000)   9.087 (+0.000)
fp8_noname              0.593 (-0.001)   9.070 (-0.017)
fp8_tensorwise          0.593 (-0.001)   9.061 (-0.026)
fp8_rowwise             0.593 (-0.000)   9.086 (-0.001)
fp8_rowwise_with_gw_hp  0.595 (+0.001)   9.087 (+0.000)

Llama3.2-3B saw up to 16.5% faster finetuning with rowwise scaling and high-precision grad_weight, a bigger improvement than tensorwise scaling alone. As before, there were no degradations in memory usage or quantized accuracy.

experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    6502.143 (+0.000%)   15.917 (+0.000%)   15.917 (+0.000%)  30.090 (+0.000%)
fp8_noname              7205.386 (+10.816%)  15.917 (+0.003%)   15.917 (+0.003%)  30.010 (-0.266%)
fp8_tensorwise          7222.198 (+11.074%)  15.917 (+0.003%)   15.917 (+0.003%)  30.010 (-0.266%)
fp8_rowwise             6387.968 (-1.756%)   15.916 (-0.002%)   15.916 (-0.002%)  29.158 (-3.096%)
fp8_rowwise_with_gw_hp  7573.698 (+16.480%)  15.917 (+0.001%)   15.917 (+0.001%)  29.516 (-1.908%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.533 (+0.000)   12.407 (+0.000)
fp8_noname              0.533 (+0.000)   12.414 (+0.007)
fp8_tensorwise          0.533 (+0.000)   12.412 (+0.005)
fp8_rowwise             0.533 (-0.000)   12.420 (+0.013)
fp8_rowwise_with_gw_hp  0.534 (+0.001)   12.416 (+0.009)

Test Plan:

Experiment command:

tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
    enable_fp8_training=true \
    fp8_recipe_name=tensorwise \
    epochs=1 \
    batch_size=16 \
    compile=true \
    dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
    checkpointer.output_dir="$LOG_DIR" \
    output_dir="${LOG_DIR}/metrics" \
    metric_logger.log_dir="${LOG_DIR}/metrics"

(full script:
https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)

Unit tests:

pytest tests -k test_convert_to_float8_training
pytest tests -k test_is_fp8_tensorwise_scaling

@pytorch-bot

pytorch-bot bot commented Apr 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2546

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit dcdeea4 with merge base 7d92c10:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@andrewor14 andrewor14 marked this pull request as draft April 1, 2025 22:38
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 1, 2025
Contributor

@ebsmothers ebsmothers left a comment


This is looking good! I'm excited to see some initial runs

Comment on lines 551 to 587
if self._enable_fp8_training:
    raise ValueError(
        "Float8 training + tensor parallelism is not supported yet"
    )
Contributor


Regarding this, are there still known gaps that need to be addressed? E.g. of the three issues from @vkuzo's comment it seems like one is closed, one was a red herring, and one is not relevant to us today (as it's related to async TP). Wondering if we're good to just gate on nightly PyTorch instead of disabling entirely?



the main issue I was worried about earlier turned out to be due to bad hardware, so we are all good!

Are there plans or a need to support async TP in torchtune?

Contributor


@vkuzo async TP is not something that has really been on our radar. But given that it shows some nice speedups, I think it would be worth exploring.

@andrewor14 andrewor14 force-pushed the fp8-finetuning branch 4 times, most recently from 52d97a5 to 4f8682d on April 2, 2025 22:30
@codecov-commenter

codecov-commenter commented Apr 2, 2025

Codecov Report

Attention: Patch coverage is 39.47368% with 23 lines in your changes missing coverage. Please review.

Project coverage is 65.47%. Comparing base (8dadbaa) to head (df526dd).
Report is 7 commits behind head on main.

Files with missing lines                 Patch %   Lines
---------------------------------------  --------  -------------
recipes/full_finetune_distributed.py     0.00%     11 Missing ⚠️
torchtune/training/quantization.py       41.17%    10 Missing ⚠️
torchtune/models/llama3/_parallelism.py  77.77%    2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2546      +/-   ##
==========================================
- Coverage   66.97%   65.47%   -1.51%     
==========================================
  Files         375      376       +1     
  Lines       22295    22438     +143     
==========================================
- Hits        14932    14691     -241     
- Misses       7363     7747     +384     


@andrewor14 andrewor14 force-pushed the fp8-finetuning branch 2 times, most recently from b818e03 to 797eb91 on April 4, 2025 13:44
@vkuzo

vkuzo commented Apr 4, 2025

this is great! Are there any performance benchmarks we can put in the PR summary?

self._clip_grad_norm = cfg.get("clip_grad_norm", None)
self._checkpoint_client = CheckpointClient(cfg)
self._enable_fp8_training = cfg.get("enable_fp8_training", False)
self._fp8_recipe_name = cfg.get("fp8_recipe_name", None)
Contributor Author


@vkuzo does this level of configuration look good to you? Any other fields you think I should expose?



I remember you mentioned a performance regression in previous experiments, and we then root caused it to the fact that Float8Linear was applied to parts of the model that torch.compile was not applied to. I'm not familiar with how torch.compile is applied in torchtune, but it would be good to ensure Float8Linear is not applied to any regions which are not using compile.

Collaborator

@nathan-az nathan-az Apr 7, 2025


@vkuzo @andrewor14 From the torchao source, it appears the rowwise scaling options do not support force_recompute_fp8_weight_in_bwd by default.

Given we are defaulting to enable_fsdp_float8_all_gather=True in the case where a recipe name isn't provided, should we also set force_recompute_fp8_weight_in_bwd, or should this be an option here in the config?

I noticed a TODO in torchao to set this to True by default in the future; I wonder if we should already do that in tune.

Contributor Author


I think @vkuzo mentioned that force_recompute_fp8_weight_in_bwd is no longer needed after a certain commit. This PR gates on the nightlies, so we don't need this flag anymore.

Contributor


I'm not familiar with how torch.compile is applied in torchtune, but it would be good to ensure Float8Linear is not applied to any regions which are not using compile.

Regarding @vkuzo's comment, torchtune does per-layer compile with separate compilation of the loss. You can see the utility we use here. This means that there may be two cases that are not covered:

  1. output projection (i.e. LM head)
  2. any linear layers in e.g. a vision encoder for a multimodal model

Given the numbers in the test plan it seems (1) is not a hard blocker. We could try it out with e.g. Llama 3.2 Vision just to make sure there are no regressions there. Otherwise I don't have any huge concerns here.
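
As an illustration of the point above, torchao's `convert_to_float8_training` accepts a `module_filter_fn` that could be used to keep particular layers (e.g. an uncompiled output projection or vision-encoder linears) in bf16. This is only a sketch of the mechanism with a toy model, not necessarily what the PR does:

```python
# Sketch: restrict fp8 conversion by filtering on module FQN.
# The toy model and the "output" prefix are illustrative only.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(256, 256) for _ in range(2))
        self.output = nn.Linear(256, 256)  # stands in for the LM head

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.output(x)


def module_filter_fn(module: nn.Module, fqn: str) -> bool:
    # Return False for modules that should stay in bf16 (i.e. skip conversion)
    return not fqn.startswith("output")


model = ToyModel()
convert_to_float8_training(
    model,
    config=Float8LinearConfig(enable_fsdp_float8_all_gather=True),
    module_filter_fn=module_filter_fn,
)
```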

Contributor Author


Sounds good. Yes we already do (1) in this PR, will try Llama 3.2 Vision

@andrewor14
Contributor Author

this is great! Are there any performance benchmarks we can put in the PR summary?

Thanks. Yeah these are still in progress, will add shortly.

@andrewor14 andrewor14 force-pushed the fp8-finetuning branch 2 times, most recently from 8a848ca to 1dfa008 on April 8, 2025 19:12
@andrewor14 andrewor14 marked this pull request as ready for review April 8, 2025 19:12
@andrewor14 andrewor14 force-pushed the fp8-finetuning branch 2 times, most recently from bd94d84 to af3dad4 on April 8, 2025 19:20
@nathan-az
Collaborator

I pulled this down to do some testing, but had trouble with tensor parallelism even after tweaking the plan.

(full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)

This script appears to set tensor_parallel_plan but not tensor_parallel_dim - by default this will be 1, so only FSDP is used. You may need to do some testing with tensor_parallel_dim set to confirm TP is working as expected. This may also be why you aren't seeing the same errors.
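
For anyone reproducing this: the TP path is only exercised when the mesh actually has a TP dimension. A hypothetical variant of the test-plan command above (overrides are illustrative; the TP plan override from the script would also be needed):

```
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
    enable_fp8_training=true \
    tensor_parallel_dim=2   # with 4 GPUs this gives a 2 (FSDP) x 2 (TP) mesh
```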

@andrewor14
Contributor Author

@nathan-az I'm running into this error now, even with the original base TP plan (no float8 parallel styles). Do you have more context on this / how you worked around this in the past?

[rank0]:   File "/home/andrewor/local/pytorch/torch/distributed/_functional_collectives.py", line 215, in all_gather_tensor
[rank0]:     res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
[rank0]:   File "/home/andrewor/local/ao/torchao/float8/float8_tensor.py", line 375, in __torch_dispatch__
[rank0]:     raise NotImplementedError(f"attempting to run {func}, this is not supported")
[rank0]: NotImplementedError: attempting to run aten.chunk.default, this is not supported

Full stack trace: https://gist.github.com/andrewor14/b4058adc32fcb1dafd70574e0eae335d

@andrewor14 andrewor14 force-pushed the fp8-finetuning branch 2 times, most recently from 1adf239 to 719dae9 on April 11, 2025 19:44
@nathan-az
Collaborator

how you worked around this in the past

Perfect - that's the exact error I got stuck on after changing the TP plan. My best guess (an absolute guess), if that error can be taken at face value, is that the chunk op needs to be implemented in torchao. See here. CC @vkuzo does that seem likely?

@andrewor14
Contributor Author

Perfect - that's the exact error I got stuck on after changing the TP plan. My best guess (an absolute guess), if that error can be taken at face value, is that the chunk op needs to be implemented in torchao. See here

I can implement that op. I'm just curious how we did it in torchtitan since we clearly have TP support there. Do you see any difference in the torchtune vs torchtitan setup that could be causing this?

@nathan-az
Collaborator

I expect there's something but it's not obvious to me. Unfortunately this is about where I hit a wall. I was able to change my error message by adding the chunk op to ao's split (not sure if this is exactly correct or if there's any nuanced difference), but hit another error that I didn't know how to address.
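
For reference, the piggybacking described above relies on chunk being expressible as split with a ceil-divided split size, which can be sanity-checked on plain tensors (illustration of the decomposition only, not torchao code):

```python
import torch

t = torch.arange(10.0)
chunks, dim = 4, 0
# torch.chunk(t, n, dim) matches torch.split(t, ceil(size / n), dim)
split_size = -(-t.shape[dim] // chunks)  # ceil division
assert all(
    torch.equal(a, b)
    for a, b in zip(torch.chunk(t, chunks, dim), torch.split(t, split_size, dim))
)
```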

One final note I can add is that I got a different error after disabling compile, about contiguous not being implemented, but I'm not sure if this is a red herring.

Just a thought - is it worth landing this with standard FP8 support for FSDP, raising an error that FP8 + TP isn't supported yet, then adding TP support in a follow-up?

@andrewor14
Contributor Author

Just a thought - is it worth landing this with standard FP8 support for FSDP, raising an error that FP8 + TP isn't supported yet, then adding TP support in a follow-up?

Yeah I think that's a good idea. @ebsmothers does this sound good to you?

@ebsmothers
Contributor

Just a thought - is it worth landing this with standard FP8 support for FSDP, raising an error that FP8 + TP isn't supported yet, then adding TP support in a follow-up?

Yeah I think that's a good idea. @ebsmothers does this sound good to you?

@andrewor14 yeah this is fine with me. Let's file an issue to get things working with TP though

@nathan-az
Collaborator

nathan-az commented Apr 16, 2025

Last note from me until we move to a separate issue for TP + FP8. By implementing chunk in torchao for the Float8Tensor (piggybacking float8_split), I had successful runs (no guarantee of correctness, just no errors), but only when compile=false. This did not work for models with tied weights for the output layer (although this is mostly the case for smaller models for which it's less likely that TP + FP8 would be critical).

Compile created some metadata errors mentioning implementing __force_to_same_metadata__, but implementing this on the Float8Tensor class did not change the error for me.

I created a gist with a slightly more minimal and self-contained test script. It has very little dependency on torchtune imports so hopefully everything that's happening is transparent, and it's useful for debugging. Right now it instantiates the full 8B, but could use a smaller model using a model builder.

@andrewor14
Contributor Author

@nathan-az Great, thank you for all your investigation! I think we do want TP to work with compile so we probably need to investigate further. Will run some additional benchmarks for Llama 3.2 and let's aim to land this PR without TP.

@andrewor14 andrewor14 force-pushed the fp8-finetuning branch 5 times, most recently from d4f38ce to 4ae0db8 on April 17, 2025 15:36
Contributor

@ebsmothers ebsmothers left a comment


Thanks for adding this! This looks good to me. Gonna re-run CI (I think there is some unrelated sporadic failure), after that let's land this!

@ebsmothers ebsmothers merged commit 1075d9c into meta-pytorch:main Apr 17, 2025
14 of 17 checks passed
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request May 12, 2025
**Summary:** Similar to meta-pytorch#1854. Update `qat_distributed` recipe to mirror `full_finetune_distributed` up until a6db644. The new major feature that is excluded from `qat_distributed` is FP8 finetuning (meta-pytorch#2546), since QAT FP8 is not supported in torchao yet.

Diff between full finetune and QAT recipes: P1809370361
```
diff --color recipes/full_finetune_distributed.py recipes/qat_distributed.py
```
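
For context, the prepare/convert flow that the QAT recipe and the quantize/eval commands below rely on looks roughly like this sketch (the toy model and direct calls are illustrative; the real recipe wires the quantizer in through the config):

```python
# Sketch: fake-quantize linears for QAT, then convert to the real
# int8-dynamic-activation / int4-weight representation after finetuning.
import torch
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

model = torch.nn.Sequential(torch.nn.Linear(256, 256))  # toy stand-in

qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = qat_quantizer.prepare(model)   # swap in fake-quantized linears
# ... run finetuning as usual ...
model = qat_quantizer.convert(model)   # produce the quantized model for eval
```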

**Test Plan:**

Finetune:
```
tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full \
    epochs=1 \
    batch_size=16 \
    dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
    checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat \
    output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/metrics \
    metric_logger.log_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/metrics \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
    quantizer.groupsize=32
```

Quantize:
```
tune run quantize --config quantization \
    model._component_=torchtune.models.llama3_2.llama3_2_3b \
    checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
    checkpointer.checkpoint_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0 \
    checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0_out \
    'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \
    checkpointer.model_type=LLAMA3 \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32
```

Eval:
```
tune run eleuther_eval --config eleuther_evaluation \
    batch_size=1 \
    'tasks=[wikitext]' \
    model._component_=torchtune.models.llama3_2.llama3_2_3b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0 \
    checkpointer.output_dir=/home/andrewor/local/logs/tune/Llama3.2-3B_alpaca_qat/epoch_0_out \
    'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32
```

Results:
```
experiment_name          tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
-----------------------  -------------------  -----------------  ----------------  -------------------
Llama3.2-3B_alpaca_full  4677.163 (+0.000%)   12.261 (+0.000%)   12.261 (+0.000%)  15.778 (+0.000%)
Llama3.2-3B_alpaca_qat   1873.316 (-59.948%)  13.047 (+6.409%)   13.047 (+6.409%)  17.226 (+9.176%)

experiment_name          hellaswag_acc                   wikitext_word_perplexity
-----------------------  ------------------------------  -------------------------------
Llama3.2-3B_alpaca_full  0.470 quant, 0.534 float        18.563 quant, 12.364 float
Llama3.2-3B_alpaca_qat   0.511 quant, recovered 63.043%  13.792 quant, recovered 76.962%
```