KEMBAR78
move float8_experimental to torchao/float8 by vkuzo · Pull Request #551 · pytorch/ao · GitHub
Skip to content

Conversation

@vkuzo
Copy link
Contributor

@vkuzo vkuzo commented Jul 29, 2024

Summary:

This PR moves https://github.com/pytorch-labs/float8_experimental to torchao/float8.

There are no logic changes here. Here is how to reproduce this PR:

  • copy float8_experimental/float8_experimental/* to torchao/float8
  • copy float8_experimental/test/* to test/float8
  • copy float8_experimental/benchmarks/* to benchmarks/float8
  • copy the README over and delete sections which no longer apply (license, installation)
  • replace float8_experimental with torchao.float8 everywhere
  • update tests to skip old PyTorch versions (only nightly supported)
  • update non-emulated tests to require an H100 for now (we can enable CI in separate PR)
  • update distributed tests to not run in CI for now (we can enable in separate PR)

Test Plan:

// run local tests, they pass
./test/float8/test_everything.sh

// run every benchmark in `benchmarks/float8`, they still work

Reviewers:

Subscribers:

Tasks:

Tags:

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/551

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c00b120 with merge base 8fa11a6 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 29, 2024
@vkuzo vkuzo requested review from drisspg and msaroufim July 29, 2024 17:24
@vkuzo vkuzo force-pushed the 20240729_move_float8 branch from fbb76ce to a89914c Compare July 29, 2024 18:17
@drisspg
Copy link
Contributor

drisspg commented Jul 29, 2024

If we wanted to change some of the file naming, we do float8/float8_.. might be a good opportunity, I think since its already namespaced under 'float8' folder it would make sense to remove the leading 'float8_'. I know you will likely not want to complicate the move though, so merely an observation

@vkuzo vkuzo force-pushed the 20240729_move_float8 branch from a89914c to 6f6b4d3 Compare July 29, 2024 18:27
@vkuzo
Copy link
Contributor Author

vkuzo commented Jul 29, 2024

If we wanted to change some of the file naming, we do float8/float8_.. might be a good opportunity, I think since its already namespaced under 'float8' folder it would make sense to remove the leading 'float8_'. I know you will likely not want to complicate the move though, so merely an observation

I'd accept a PR if someone wants to clean that up. Note: last week we cleaned up the user facing UX and exposed it through the top level __init__.py file, so file renames like you suggest are no longer user facing.

@vkuzo vkuzo force-pushed the 20240729_move_float8 branch 3 times, most recently from aabb497 to 1d96139 Compare July 29, 2024 18:43
Summary:

This PR moves https://github.com/pytorch-labs/float8_experimental to
torchao/float8.

There are no logic changes here. Here is how to reproduce this PR:
* copy float8_experimental/float8_experimental/* to torchao/float8
* copy float8_experimental/test/* to test/float8
* copy float8_experimental/benchmarks/* to benchmarks/float8
* copy the README over and delete sections which no longer apply
  (license, installation)
* replace `float8_experimental` with `torchao.float8` everywhere

Test Plan:

```
// run local tests, they pass
./test/float8/test_everything.sh

// run every benchmark in `benchmarks/float8`, they still work
```

Reviewers:

Subscribers:

Tasks:

Tags:
@vkuzo vkuzo force-pushed the 20240729_move_float8 branch from 1d96139 to c00b120 Compare July 29, 2024 19:00
@jerryzh168
Copy link
Contributor

thoughts on torchao/float8 v.s. torchao/dtypes/float8?


#### Float8

[torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wanna mention a topline speedup marketing number?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

our last public number is from 2023H2, we plan to release new speedups in ~weeks but not ready yet. Will add it here when it's posted.

@vkuzo
Copy link
Contributor Author

vkuzo commented Jul 29, 2024

thoughts on torchao/float8 v.s. torchao/dtypes/float8?

Sure, my reasoning was that float8_experimental is a workflow, similar to torchao/quantization or torchao/sparsity. There is a Float8Tensor object inside of torchao/float8 which is a private API, if in the future that becomes a public API it could make sense to move it into torchao/dtypes.

@facebook-github-bot
Copy link
Contributor

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot facebook-github-bot merged commit e1aee63 into main Jul 30, 2024
@awgu
Copy link
Contributor

awgu commented Jul 30, 2024

what happens to float8_experimental now? is it deprecated?

@vkuzo
Copy link
Contributor Author

vkuzo commented Jul 30, 2024

what happens to float8_experimental now? is it deprecated?

I'm finishing up the migration today, just wanted to verify nothing was broken before I go ahead, which I have verified this morning. float8_experimental will be archived and get a README.md section pointing people here for the new code location.

vkuzo added a commit to pytorch/torchtitan that referenced this pull request Jul 30, 2024
Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo added a commit to pytorch/torchtitan that referenced this pull request Jul 30, 2024
Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
@vkuzo vkuzo mentioned this pull request Jul 30, 2024
vkuzo added a commit to pytorch/torchtitan that referenced this pull request Jul 30, 2024
Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
tianyu-l added a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 13, 2024
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e4849
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20f
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After meta-pytorch/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options on how to configure
each linear (scaling type, scaling granularity, etc) - saving that for a
future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b492495
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9e
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e4
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d
Pull Request resolved: pytorch#449

* [Cleanup] Remove libuv from run_llama_train.sh

libuv is now enabled by default.

we can proably do without the educational blurb there, and don't need
the env either since the default has landed.

ghstack-source-id: 68c8d2a
Pull Request resolved: pytorch#453

* [Cleanup] Organize run_llama_train.sh options

Just a little code motion but it looks cleaner to me this way

ghstack-source-id: 055fbd5
Pull Request resolved: pytorch#454

* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh

Make each script simpler to read

ghstack-source-id: ba3aa65
Pull Request resolved: pytorch#455

* [Cleanup] Remove unused TRAINER_DIR

This argument seems to be left over from older times- it is not used
anywhere in the codebase.

ghstack-source-id: abbcf82
Pull Request resolved: pytorch#456

* Add educational code pointers to top level README

ghstack-source-id: 522aa2f
Pull Request resolved: pytorch#457

* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)

we have landed fp8 all-gather optimizations in float8_experimental
meta-pytorch/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files

* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)

make sure to only import float8_experimental when fp8 is enabled

for 4 gpu CI, make sure we can import float8_experimental correctly in
CI

`python -m pip install
git+https://github.com/pytorch-labs/float8_experimental.git`

* skip fp8 CI on non-H100 GPUs (pytorch#465)

skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`

this makes 4 GPU CI healthy again

* clean up float8 configs in torchtitan (pytorch#466)

Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d
Pull Request resolved: pytorch#432

* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)

fixed my bug in float8_experimental. now we can torch.compile
transfromer blocks with FSDP float8 all-gather
meta-pytorch/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see compiled region in cpu thread and float8
malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">

* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)

**keep model.output as nn.Linear**: it's a common practice to NOT apply
fp8 on final output layer
* specify `skip_fqn_list` in swapping
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel`

credit to @awgu, we do not need tokentizer vacab size to be divisible by
16 pytorch#461

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">

* remove CI for FSDP2 + fp8 all-gather (pytorch#470)

per discussion from
pytorch#469 (comment)

we are planning BC breaking changes in float8_experimental. remove CI
for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we
can discuss bringing it back

* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)

This PR adds some enhancements for supporting async tp:

1 - if async tp is active, auto updates the torch.dynamo cache limit to
10K. If this is not updated, async tp will not be activated on larger
models as it will quietly stop compilation due to 'cache limit reached'
with no info for the user.
This config update is logged. 

2 - if async tp is enabled, verifies that torch.compile is set to true
for this job config. If not, it warns and then activates torch.compile
to ensure user gets working async tp. (see WARNING in below screenshot)

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Updates the 'Applied Tensor Parallel' to the model to be 'Applied
Async Tensor Parallel' when async tp is active to make it clear in the
logs which TP is active. (see above screenshot)

* Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint.  Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard.

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).

ghstack-source-id: c069d21
Pull Request resolved: pytorch#460

* update float8 integration after UX changes (pytorch#484)

Summary:

float8_experimental landed various BC-breaking UX changes last week.
This PR updates torchtitan to work with the version of
float8_experimental after
meta-pytorch/float8_experimental#332 and
meta-pytorch/float8_experimental#337

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603
Pull Request resolved: pytorch#485

* Used `partial` instead of global vars for LR scheduling

ghstack-source-id: 12c4418
Pull Request resolved: pytorch#487

* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)

As title, while testing on 405B model, I found that we need to somehow
need the logs for some training params. So added some here. Tested
locally and the logging is shown as in the screenshot:


<img width="900" alt="image"
src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">

* make float8 scaling type configurable (pytorch#489)

Summary:

Adds config options to configure float8 scaling type for input, weight,
grad_output.

Performance is not ideal yet, but that's because we have not optimized
it.

Test Plan:

```
// repeat for input, weight, grad_out
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)

This was approved in pytorch#490, but
merged into the wrong branch, merging this into main

* move float8 callsites to torchao.float8 (pytorch#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [BE][1/n] simplify train.py

ghstack-source-id: 3879e76
Pull Request resolved: pytorch#494

* [BE][2/n] use proper method signatures in parallelize_llama

ghstack-source-id: 17a1ee9
Pull Request resolved: pytorch#495

* [BE][3/n] wrap fp8 logic using Float8Handler

ghstack-source-id: e94c7f6
Pull Request resolved: pytorch#496

* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)

With the official launch of LLaMa 3.1 model, we want to add the config
to TorchTitan. Of course, there are more work to be done, but we want to
go an incremental way. So more PRs will be needed.

For now, we try on 128 GPUs with current config (TP=8, FSDP=16). The
perf number is wps: 109 mfu: 29%.

Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
<img width="1037" alt="image"
src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">


Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).

![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)

* [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 587e3d6
Pull Request resolved: pytorch#498

* some compile-related updates

ghstack-source-id: 63af802
Pull Request resolved: pytorch#443

* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)

As title, use `8e-5` rather than `0.8e-4`.

* [BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4ad
Pull Request resolved: pytorch#499

* [fix] float8 should be applied on all model_parts

ghstack-source-id: 52ed683
Pull Request resolved: pytorch#500

* Add warning to compile rmsnorm (pytorch#505)

as titled, add warning to compile rmsnorm as it's not fully ready yet,
i.e. this issue pytorch#497

We can remove this warning once we fix the issue

* add float8 to README (pytorch#509)

add float8 link in README so we can redirect people from dev-discuss
post to torchtitan repo


README looks like this after rendering
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM"
src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM"
src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and traces are looking good
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM"
src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">

* address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a
Pull Request resolved: pytorch#508

* [BE][5/n] simply pp vs. non-pp set up

ghstack-source-id: 003bfbf
Pull Request resolved: pytorch#510

* [BE][6/n] replace large c4_mini datasets by c4_test with the first 2K entries

ghstack-source-id: 319f496
Pull Request resolved: pytorch#512

* Create composability.md (pytorch#511)

Explain the rationale and challenges behind certain changes we made to
llama model to support 3D parallelism.

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>

* depend on torchdata 0.8.0 instead of nightly

ghstack-source-id: 1965d31
Pull Request resolved: pytorch#513

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
tianyu-l added a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 13, 2024
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e4849
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20f
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After meta-pytorch/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options on how to configure
each linear (scaling type, scaling granularity, etc) - saving that for a
future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b492495
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9e
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e4
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d
Pull Request resolved: pytorch#449

* [Cleanup] Remove libuv from run_llama_train.sh

libuv is now enabled by default.

we can proably do without the educational blurb there, and don't need
the env either since the default has landed.

ghstack-source-id: 68c8d2a
Pull Request resolved: pytorch#453

* [Cleanup] Organize run_llama_train.sh options

Just a little code motion but it looks cleaner to me this way

ghstack-source-id: 055fbd5
Pull Request resolved: pytorch#454

* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh

Make each script simpler to read

ghstack-source-id: ba3aa65
Pull Request resolved: pytorch#455

* [Cleanup] Remove unused TRAINER_DIR

This argument seems to be left over from older times- it is not used
anywhere in the codebase.

ghstack-source-id: abbcf82
Pull Request resolved: pytorch#456

* Add educational code pointers to top level README

ghstack-source-id: 522aa2f
Pull Request resolved: pytorch#457

* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)

we have landed fp8 all-gather optimizations in float8_experimental
meta-pytorch/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files

* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)

make sure to only import float8_experimental when fp8 is enabled

for 4 gpu CI, make sure we can import float8_experimental correctly in
CI

`python -m pip install
git+https://github.com/pytorch-labs/float8_experimental.git`

* skip fp8 CI on non-H100 GPUs (pytorch#465)

skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`

this makes 4 GPU CI healthy again

* clean up float8 configs in torchtitan (pytorch#466)

Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d
Pull Request resolved: pytorch#432

* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)

fixed my bug in float8_experimental. now we can torch.compile
transfromer blocks with FSDP float8 all-gather
meta-pytorch/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see compiled region in cpu thread and float8
malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">

* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)

**keep model.output as nn.Linear**: it's a common practice to NOT apply
fp8 on final output layer
* specify `skip_fqn_list` in swapping
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel`

credit to @awgu, we do not need tokentizer vacab size to be divisible by
16 pytorch#461

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">

* remove CI for FSDP2 + fp8 all-gather (pytorch#470)

per discussion from
pytorch#469 (comment)

we are planning BC breaking changes in float8_experimental. remove CI
for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we
can discuss bringing it back

* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)

This PR adds some enhancements for supporting async tp:

1 - if async tp is active, auto updates the torch.dynamo cache limit to
10K. If this is not updated, async tp will not be activated on larger
models as it will quietly stop compilation due to 'cache limit reached'
with no info for the user.
This config update is logged. 

2 - if async tp is enabled, verifies that torch.compile is set to true
for this job config. If not, it warns and then activates torch.compile
to ensure user gets working async tp. (see WARNING in below screenshot)

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Updates the 'Applied Tensor Parallel' to the model to be 'Applied
Async Tensor Parallel' when async tp is active to make it clear in the
logs which TP is active. (see above screenshot)

* Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint.  Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard.

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).

ghstack-source-id: c069d21
Pull Request resolved: pytorch#460

* update float8 integration after UX changes (pytorch#484)

Summary:

float8_experimental landed various BC-breaking UX changes last week.
This PR updates torchtitan to work with the version of
float8_experimental after
meta-pytorch/float8_experimental#332 and
meta-pytorch/float8_experimental#337

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603
Pull Request resolved: pytorch#485

* Used `partial` instead of global vars for LR scheduling

ghstack-source-id: 12c4418
Pull Request resolved: pytorch#487

* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)

As title, while testing on 405B model, I found that we need to somehow
need the logs for some training params. So added some here. Tested
locally and the logging is shown as in the screenshot:


<img width="900" alt="image"
src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">

* make float8 scaling type configurable (pytorch#489)

Summary:

Adds config options to configure float8 scaling type for input, weight,
grad_output.

Performance is not ideal yet, but that's because we have not optimized
it.

Test Plan:

```
// repeat for input, weight, grad_out
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)

This was approved in pytorch#490, but
merged into the wrong branch, merging this into main

* move float8 callsites to torchao.float8 (pytorch#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [BE][1/n] simplify train.py

ghstack-source-id: 3879e76
Pull Request resolved: pytorch#494

* [BE][2/n] use proper method signatures in parallelize_llama

ghstack-source-id: 17a1ee9
Pull Request resolved: pytorch#495

* [BE][3/n] wrap fp8 logic using Float8Handler

ghstack-source-id: e94c7f6
Pull Request resolved: pytorch#496

* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)

With the official launch of LLaMa 3.1 model, we want to add the config
to TorchTitan. Of course, there are more work to be done, but we want to
go an incremental way. So more PRs will be needed.

For now, we try on 128 GPUs with current config (TP=8, FSDP=16). The
perf number is wps: 109 mfu: 29%.

Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
<img width="1037" alt="image"
src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">


Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).

![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)

* [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 587e3d6
Pull Request resolved: pytorch#498

* some compile-related updates

ghstack-source-id: 63af802
Pull Request resolved: pytorch#443

* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)

As title, use `8e-5` rather than `0.8e-4`.

* [BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4ad
Pull Request resolved: pytorch#499

* [fix] float8 should be applied on all model_parts

ghstack-source-id: 52ed683
Pull Request resolved: pytorch#500

* Add warning to compile rmsnorm (pytorch#505)

as titled, add warning to compile rmsnorm as it's not fully ready yet,
i.e. this issue pytorch#497

We can remove this warning once we fix the issue

* add float8 to README (pytorch#509)

add float8 link in README so we can redirect people from dev-discuss
post to torchtitan repo


README looks like this after rendering
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM"
src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM"
src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and traces are looking good
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM"
src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">

* address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a
Pull Request resolved: pytorch#508

* [BE][5/n] simply pp vs. non-pp set up

ghstack-source-id: 003bfbf
Pull Request resolved: pytorch#510

* [BE][6/n] replace large c4_mini datasets by c4_test with the first 2K entries

ghstack-source-id: 319f496
Pull Request resolved: pytorch#512

* Create composability.md (pytorch#511)

Explain the rationale and challenges behind certain changes we made to
llama model to support 3D parallelism.

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>

* depend on torchdata 0.8.0 instead of nightly

ghstack-source-id: 1965d31
Pull Request resolved: pytorch#513

* add support for torchbench

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
tianyu-l added a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e4849
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20f
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After meta-pytorch/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options on how to configure
each linear (scaling type, scaling granularity, etc) - saving that for a
future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b492495
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9e
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e4
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d
Pull Request resolved: pytorch#449

* [Cleanup] Remove libuv from run_llama_train.sh

libuv is now enabled by default.

we can proably do without the educational blurb there, and don't need
the env either since the default has landed.

ghstack-source-id: 68c8d2a
Pull Request resolved: pytorch#453

* [Cleanup] Organize run_llama_train.sh options

Just a little code motion but it looks cleaner to me this way

ghstack-source-id: 055fbd5
Pull Request resolved: pytorch#454

* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh

Make each script simpler to read

ghstack-source-id: ba3aa65
Pull Request resolved: pytorch#455

* [Cleanup] Remove unused TRAINER_DIR

This argument seems to be left over from older times- it is not used
anywhere in the codebase.

ghstack-source-id: abbcf82
Pull Request resolved: pytorch#456

* Add educational code pointers to top level README

ghstack-source-id: 522aa2f
Pull Request resolved: pytorch#457

* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)

we have landed fp8 all-gather optimizations in float8_experimental
meta-pytorch/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files

* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)

make sure to only import float8_experimental when fp8 is enabled

for 4 gpu CI, make sure we can import float8_experimental correctly in
CI

`python -m pip install
git+https://github.com/pytorch-labs/float8_experimental.git`

* skip fp8 CI on non-H100 GPUs (pytorch#465)

skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`

this makes 4 GPU CI healthy again

* clean up float8 configs in torchtitan (pytorch#466)

Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d
Pull Request resolved: pytorch#432

* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)

fixed my bug in float8_experimental. now we can torch.compile
transfromer blocks with FSDP float8 all-gather
meta-pytorch/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see compiled region in cpu thread and float8
malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">

* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)

**keep model.output as nn.Linear**: it's a common practice to NOT apply
fp8 on final output layer
* specify `skip_fqn_list` in swapping
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel`

credit to @awgu, we do not need tokentizer vacab size to be divisible by
16 pytorch#461

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">

* remove CI for FSDP2 + fp8 all-gather (pytorch#470)

per discussion from
pytorch#469 (comment)

we are planning BC breaking changes in float8_experimental. remove CI
for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we
can discuss bringing it back

* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)

This PR adds some enhancements for supporting async tp:

1 - if async tp is active, auto updates the torch.dynamo cache limit to
10K. If this is not updated, async tp will not be activated on larger
models as it will quietly stop compilation due to 'cache limit reached'
with no info for the user.
This config update is logged. 

2 - if async tp is enabled, verifies that torch.compile is set to true
for this job config. If not, it warns and then activates torch.compile
to ensure user gets working async tp. (see WARNING in below screenshot)

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Updates the 'Applied Tensor Parallel' to the model to be 'Applied
Async Tensor Parallel' when async tp is active to make it clear in the
logs which TP is active. (see above screenshot)

* Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint.  Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard.

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).

ghstack-source-id: c069d21
Pull Request resolved: pytorch#460

* update float8 integration after UX changes (pytorch#484)

Summary:

float8_experimental landed various BC-breaking UX changes last week.
This PR updates torchtitan to work with the version of
float8_experimental after
meta-pytorch/float8_experimental#332 and
meta-pytorch/float8_experimental#337

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603
Pull Request resolved: pytorch#485

* Used `partial` instead of global vars for LR scheduling

ghstack-source-id: 12c4418
Pull Request resolved: pytorch#487

* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)

As title, while testing on 405B model, I found that we need to somehow
need the logs for some training params. So added some here. Tested
locally and the logging is shown as in the screenshot:


<img width="900" alt="image"
src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">

* make float8 scaling type configurable (pytorch#489)

Summary:

Adds config options to configure float8 scaling type for input, weight,
grad_output.

Performance is not ideal yet, but that's because we have not optimized
it.

Test Plan:

```
// repeat for input, weight, grad_out
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)

This was approved in pytorch#490, but
merged into the wrong branch, merging this into main

* move float8 callsites to torchao.float8 (pytorch#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [BE][1/n] simplify train.py

ghstack-source-id: 3879e76
Pull Request resolved: pytorch#494

* [BE][2/n] use proper method signatures in parallelize_llama

ghstack-source-id: 17a1ee9
Pull Request resolved: pytorch#495

* [BE][3/n] wrap fp8 logic using Float8Handler

ghstack-source-id: e94c7f6
Pull Request resolved: pytorch#496

* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)

With the official launch of LLaMa 3.1 model, we want to add the config
to TorchTitan. Of course, there are more work to be done, but we want to
go an incremental way. So more PRs will be needed.

For now, we try on 128 GPUs with current config (TP=8, FSDP=16). The
perf number is wps: 109 mfu: 29%.

Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
<img width="1037" alt="image"
src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">


Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).

![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)

* [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 587e3d6
Pull Request resolved: pytorch#498

* some compile-related updates

ghstack-source-id: 63af802
Pull Request resolved: pytorch#443

* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)

As title, use `8e-5` rather than `0.8e-4`.

* [BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4ad
Pull Request resolved: pytorch#499

* [fix] float8 should be applied on all model_parts

ghstack-source-id: 52ed683
Pull Request resolved: pytorch#500

* Add warning to compile rmsnorm (pytorch#505)

as titled, add warning to compile rmsnorm as it's not fully ready yet,
i.e. this issue pytorch#497

We can remove this warning once we fix the issue

* add float8 to README (pytorch#509)

add float8 link in README so we can redirect people from dev-discuss
post to torchtitan repo


README looks like this after rendering
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM"
src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM"
src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and traces are looking good
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM"
src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">

* address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a
Pull Request resolved: pytorch#508

* [BE][5/n] simply pp vs. non-pp set up

ghstack-source-id: 003bfbf
Pull Request resolved: pytorch#510

* [BE][6/n] replace large c4_mini datasets by c4_test with the first 2K entries

ghstack-source-id: 319f496
Pull Request resolved: pytorch#512

* Create composability.md (pytorch#511)

Explain the rationale and challenges behind certain changes we made to
llama model to support 3D parallelism.

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>

* depend on torchdata 0.8.0 instead of nightly

ghstack-source-id: 1965d31
Pull Request resolved: pytorch#513

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
tianyu-l added a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e4849
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20f
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After meta-pytorch/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options on how to configure
each linear (scaling type, scaling granularity, etc) - saving that for a
future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b492495
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9e
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e4
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d
Pull Request resolved: pytorch#449

* [Cleanup] Remove libuv from run_llama_train.sh

libuv is now enabled by default.

we can proably do without the educational blurb there, and don't need
the env either since the default has landed.

ghstack-source-id: 68c8d2a
Pull Request resolved: pytorch#453

* [Cleanup] Organize run_llama_train.sh options

Just a little code motion but it looks cleaner to me this way

ghstack-source-id: 055fbd5
Pull Request resolved: pytorch#454

* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh

Make each script simpler to read

ghstack-source-id: ba3aa65
Pull Request resolved: pytorch#455

* [Cleanup] Remove unused TRAINER_DIR

This argument seems to be left over from older times- it is not used
anywhere in the codebase.

ghstack-source-id: abbcf82
Pull Request resolved: pytorch#456

* Add educational code pointers to top level README

ghstack-source-id: 522aa2f
Pull Request resolved: pytorch#457

* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)

we have landed fp8 all-gather optimizations in float8_experimental
meta-pytorch/float8_experimental#266

this PR proposes torchtitan changes. also include fp8 in CI
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather are added to CI
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather are locally tested. will add them to CI after
uploading a new tokenizer with vacab size 2560 (divisible by 16)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather do not have any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vacab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files

* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)

make sure to only import float8_experimental when fp8 is enabled

for 4 gpu CI, make sure we can import float8_experimental correctly in
CI

`python -m pip install
git+https://github.com/pytorch-labs/float8_experimental.git`

* skip fp8 CI on non-H100 GPUs (pytorch#465)

skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`

this makes 4 GPU CI healthy again

* clean up float8 configs in torchtitan (pytorch#466)

Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d
Pull Request resolved: pytorch#432

* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)

fixed my bug in float8_experimental. now we can torch.compile
transfromer blocks with FSDP float8 all-gather
meta-pytorch/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see compiled region in cpu thread and float8
malmul `sm90_xmma_gemm_e4m3bf16...` in cuda stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">

* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)

**keep model.output as nn.Linear**: it's a common practice to NOT apply
fp8 on final output layer
* specify `skip_fqn_list` in swapping
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel`

credit to @awgu, we do not need tokentizer vacab size to be divisible by
16 pytorch#461

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">

* remove CI for FSDP2 + fp8 all-gather (pytorch#470)

per discussion from
pytorch#469 (comment)

we are planning BC breaking changes in float8_experimental. remove CI
for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we
can discuss bringing it back

* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)

This PR adds some enhancements for supporting async tp:

1 - if async tp is active, auto updates the torch.dynamo cache limit to
10K. If this is not updated, async tp will not be activated on larger
models as it will quietly stop compilation due to 'cache limit reached'
with no info for the user.
This config update is logged. 

2 - if async tp is enabled, verifies that torch.compile is set to true
for this job config. If not, it warns and then activates torch.compile
to ensure user gets working async tp. (see WARNING in below screenshot)

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Updates the 'Applied Tensor Parallel' to the model to be 'Applied
Async Tensor Parallel' when async tp is active to make it clear in the
logs which TP is active. (see above screenshot)

* Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint.  Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard.

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).

ghstack-source-id: c069d21
Pull Request resolved: pytorch#460

* update float8 integration after UX changes (pytorch#484)

Summary:

float8_experimental landed various BC-breaking UX changes last week.
This PR updates torchtitan to work with the version of
float8_experimental after
meta-pytorch/float8_experimental#332 and
meta-pytorch/float8_experimental#337

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603
Pull Request resolved: pytorch#485

* Used `partial` instead of global vars for LR scheduling

ghstack-source-id: 12c4418
Pull Request resolved: pytorch#487

* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)

As title, while testing on 405B model, I found that we need to somehow
need the logs for some training params. So added some here. Tested
locally and the logging is shown as in the screenshot:


<img width="900" alt="image"
src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">

* make float8 scaling type configurable (pytorch#489)

Summary:

Adds config options to configure float8 scaling type for input, weight,
grad_output.

Performance is not ideal yet, but that's because we have not optimized
it.

Test Plan:

```
// repeat for input, weight, grad_out
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)

This was approved in pytorch#490, but
merged into the wrong branch, merging this into main

* move float8 callsites to torchao.float8 (pytorch#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [BE][1/n] simplify train.py

ghstack-source-id: 3879e76
Pull Request resolved: pytorch#494

* [BE][2/n] use proper method signatures in parallelize_llama

ghstack-source-id: 17a1ee9
Pull Request resolved: pytorch#495

* [BE][3/n] wrap fp8 logic using Float8Handler

ghstack-source-id: e94c7f6
Pull Request resolved: pytorch#496

* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)

With the official launch of LLaMa 3.1 model, we want to add the config
to TorchTitan. Of course, there are more work to be done, but we want to
go an incremental way. So more PRs will be needed.

For now, we try on 128 GPUs with current config (TP=8, FSDP=16). The
perf number is wps: 109 mfu: 29%.

Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
<img width="1037" alt="image"
src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">


Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).

![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)

* [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 587e3d6
Pull Request resolved: pytorch#498

* some compile-related updates

ghstack-source-id: 63af802
Pull Request resolved: pytorch#443

* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)

As title, use `8e-5` rather than `0.8e-4`.

* [BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4ad
Pull Request resolved: pytorch#499

* [fix] float8 should be applied on all model_parts

ghstack-source-id: 52ed683
Pull Request resolved: pytorch#500

* Add warning to compile rmsnorm (pytorch#505)

as titled, add warning to compile rmsnorm as it's not fully ready yet,
i.e. this issue pytorch#497

We can remove this warning once we fix the issue

* add float8 to README (pytorch#509)

add float8 link in README so we can redirect people from dev-discuss
post to torchtitan repo


README looks like this after rendering
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM"
src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM"
src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and traces are looking good
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM"
src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">

* address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a
Pull Request resolved: pytorch#508

* [BE][5/n] simply pp vs. non-pp set up

ghstack-source-id: 003bfbf
Pull Request resolved: pytorch#510

* [BE][6/n] replace large c4_mini datasets by c4_test with the first 2K entries

ghstack-source-id: 319f496
Pull Request resolved: pytorch#512

* Create composability.md (pytorch#511)

Explain the rationale and challenges behind certain changes we made to
llama model to support 3D parallelism.

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>

* depend on torchdata 0.8.0 instead of nightly

ghstack-source-id: 1965d31
Pull Request resolved: pytorch#513

* add support for torchbench

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
tianyu-l pushed a commit to pytorch/torchtitan that referenced this pull request Aug 16, 2024
Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* make --device fast the default

* Update iOS.md (pytorch#517)

* Update iOS.md

* Update iOS.md

* Pip to pip3 (pytorch#504)

* remove macos-12 test

* pip to pip3

* break aoti CI jobs separately (pytorch#500)

* init

* fixes

* more fixes

* fixes

* fix

* fix

* bug fix

* add objcopy update

* suppress int8

* undefined variable

---------

Co-authored-by: Michael Gschwind <mikekg@meta.com>

* Support llama3 in chat in run.cpp  (pytorch#486)

* refactor chat runner in preparation for llama3

* add sketch for llama3 prompt template and move to returning tokens

* fix tiktoken

* fixes to chat

* add default llama_ver

* Add tests for quantize json, add cuda device specification and precision to cuda.json (pytorch#519)

* remove code for no KV Cache path (pytorch#527)

* Update ADVANCED-USERS.md (pytorch#529)

Update Advanced Users description to reflect changes in the repo since the description was initially created.

* runner-aoti on cuda (pytorch#531)

* runner-aoti on cuda

* transfer results back to CPU

* transfer results back to CPU

* runner-aoti on cuda

* Update runner_build.md (pytorch#530)

Update description of runner and build process in runner_build.md

* clean up runner code a little (pytorch#532)

* clean up runner code a little

* update

* update

* pull out generate loop in chat

* updates

* edit docs

* typo

* move int8 linear class and function into qops.py (pytorch#534)

* add dtype tests for runner-aoti + runner-et (pytorch#539)

* add dtype tests for runner-aoti + runner-et

* typo

* Quantized embedding (pytorch#536)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* Move Linear int4 to qops (pytorch#537)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* move int4 linear to qops

* Revert "add dtype tests for runner-aoti + runner-et (pytorch#539)" (pytorch#548)

This reverts commit a7a24577a65be67ac9ae4dc05452f35d9c49e5d1.

* fix generate for llama3 (pytorch#538)

* fix generate for llama3

* switch more things to C

* remove C++ header

* add delegation visualization instructions (pytorch#551)

* Add dtype runner aoti (pytorch#552)

* add dtype tests for runner-aoti + runner-et

* typo

* add dtype test runner-aoti

* test sdpa with fp16 (pytorch#553)

* test sdpa with fp16

* kv cache fp32

* typo

* update (pytorch#560)

* Only support newest versions of lm-eval (pytorch#556)

Summary:
remove support for lm-eval 0.3 to reduce the options we have

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* split cpu eval CI by dtype (pytorch#554)

* split cpu eval CI by dtype

* fix

* differentiate names with checks

* keep one name the same as old

* fix

* Removing duplicate HF issue message from README (pytorch#559)

Co-authored-by: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>

* doc updates (pytorch#567)

* Add VM-safe MPS check

---------

Co-authored-by: Anthony Shoumikhin <anthony@shoumikh.in>
Co-authored-by: metascroy <161522778+metascroy@users.noreply.github.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: lucylq <lfq@meta.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* code beautification

* code beautification, move functions together

* make --device fast the default (pytorch#515)

* make --device fast the default

* Update iOS.md (pytorch#517)

* Update iOS.md

* Update iOS.md

* Pip to pip3 (pytorch#504)

* remove macos-12 test

* pip to pip3

* break aoti CI jobs separately (pytorch#500)

* init

* fixes

* more fixes

* fixes

* fix

* fix

* bug fix

* add objcopy update

* suppress int8

* undefined variable

---------

Co-authored-by: Michael Gschwind <mikekg@meta.com>

* Support llama3 in chat in run.cpp  (pytorch#486)

* refactor chat runner in preparation for llama3

* add sketch for llama3 prompt template and move to returning tokens

* fix tiktoken

* fixes to chat

* add default llama_ver

* Add tests for quantize json, add cuda device specification and precision to cuda.json (pytorch#519)

* remove code for no KV Cache path (pytorch#527)

* Update ADVANCED-USERS.md (pytorch#529)

Update Advanced Users description to reflect changes in the repo since the description was initially created.

* runner-aoti on cuda (pytorch#531)

* runner-aoti on cuda

* transfer results back to CPU

* transfer results back to CPU

* runner-aoti on cuda

* Update runner_build.md (pytorch#530)

Update description of runner and build process in runner_build.md

* clean up runner code a little (pytorch#532)

* clean up runner code a little

* update

* update

* pull out generate loop in chat

* updates

* edit docs

* typo

* move int8 linear class and function into qops.py (pytorch#534)

* add dtype tests for runner-aoti + runner-et (pytorch#539)

* add dtype tests for runner-aoti + runner-et

* typo

* Quantized embedding (pytorch#536)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* Move Linear int4 to qops (pytorch#537)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* move int4 linear to qops

* Revert "add dtype tests for runner-aoti + runner-et (pytorch#539)" (pytorch#548)

This reverts commit a7a24577a65be67ac9ae4dc05452f35d9c49e5d1.

* fix generate for llama3 (pytorch#538)

* fix generate for llama3

* switch more things to C

* remove C++ header

* add delegation visualization instructions (pytorch#551)

* Add dtype runner aoti (pytorch#552)

* add dtype tests for runner-aoti + runner-et

* typo

* add dtype test runner-aoti

* test sdpa with fp16 (pytorch#553)

* test sdpa with fp16

* kv cache fp32

* typo

* update (pytorch#560)

* Only support newest versions of lm-eval (pytorch#556)

Summary:
remove support for lm-eval 0.3 to reduce the options we have

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* split cpu eval CI by dtype (pytorch#554)

* split cpu eval CI by dtype

* fix

* differentiate names with checks

* keep one name the same as old

* fix

* Removing duplicate HF issue message from README (pytorch#559)

Co-authored-by: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>

* doc updates (pytorch#567)

* Add VM-safe MPS check

---------

Co-authored-by: Anthony Shoumikhin <anthony@shoumikh.in>
Co-authored-by: metascroy <161522778+metascroy@users.noreply.github.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: lucylq <lfq@meta.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>

* add unpacking support (pytorch#525)

* add unpacking support

* fix typos and linter

* perform parallel prefill when possible (pytorch#568)

* perform parallel prefill when possible

* typo

* disable hack

* remove print

* remove debug messages which prevent export

* fixes

* stream results in generate.py (#571)

* remove logging interfering with export

---------

Co-authored-by: Anthony Shoumikhin <anthony@shoumikh.in>
Co-authored-by: metascroy <161522778+metascroy@users.noreply.github.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: lucylq <lfq@meta.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants