[TP] fully rewrite Tensor Parallel APIs #114732
Conversation
This PR rewrites the Tensor Parallel implementation. The Tensor Parallel APIs are supposed to be a very thin wrapper around the DTensor APIs, but the current implementation has become too messy and buggy, and it is really hard to debug what went wrong when using it. It is crucially important for advanced users and developers to be able to understand the API and its implementation without going through many different functions and utils, so that they can trust what happens under the hood. In particular, this PR:

* Makes ParallelStyle a real contract API for parallelize_module to take. Each concrete ParallelStyle only needs to implement `apply` to apply its sharding to an nn.Module; all unnecessary fields are removed. This also enables easier ParallelStyle authoring going forward.
* Keeps the ColwiseParallel and RowwiseParallel public interfaces, but refactors them so that parameter sharding and input/output handling live within the style itself, making it easy to see how Linear/Embedding layers are sharded and how the input/output transformations are performed.
* Removes the private _prepare_input/_prepare_output_fn fields from both ColwiseParallel and RowwiseParallel. Since we have shown deprecation messages in nightly for a while, TP is a prototype release, and the fields are private, it should be safe to remove them.
* Refactors the recently landed PrepareModuleInput/Output styles: renames output_layouts to desired_input/output_layouts, groups the functions inside the style itself, and removes the default arguments for these two styles so users must specify them and think about the sharding layouts. Also fixes bugs around not handling the `use_local_output` flag.
* Makes default arguments None instead of Placement objects; it is standard Python practice not to use custom object instances as default arguments.
* Removes all dead APIs (i.e. the PairwiseParallel and SequenceParallel styles and all prepare input/output functions), since we have shown deprecation messages for a while and are in the process of removing them from the tests.
* Throws a deprecation warning for `tp_mesh_dim`, since we recommend device mesh slicing/indexing instead of manually specifying a mesh dim.
* Rewrites the documentation for every ParallelStyle and makes it clearer what each style does.

TODOs:

* Rewrite the TP tests to adjust for the changes in this PR.
* Add more tests to guard the bug fixes.
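For readers landing here, a minimal usage sketch of the reworked API described above; the module paths and call signatures follow the PyTorch releases that contain this rewrite (roughly 2.2+) and are meant as an illustration, not as part of this PR's diff:

```python
# Minimal sketch: tensor-parallelize a 2-layer MLP with the reworked styles.
# Assumes a typical torchrun launch (one process per GPU); mesh shape is illustrative.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class MLP(nn.Module):
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim)
        self.w2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


tp_mesh = init_device_mesh("cuda", (8,))  # 1-D tensor-parallel mesh over 8 ranks

# Each ParallelStyle is now a self-contained contract: it shards its module's
# parameters and performs the input/output layout transformations itself.
tp_model = parallelize_module(
    MLP().cuda(),
    tp_mesh,
    {
        "w1": ColwiseParallel(),  # shard the first Linear column-wise (output dim)
        "w2": RowwiseParallel(),  # shard the second Linear row-wise (input dim)
    },
)
```

Pairing a ColwiseParallel layer with a following RowwiseParallel layer works with the defaults quoted later in this thread (replicated input / Shard(-1) output for colwise, Shard(-1) input / replicated output for rowwise), so no extra redistribution is needed between the two Linear layers.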
Overall LGTM! Should we wait for Tianyu's PR deprecating PairwiseParallel to land first, to avoid any CI breakage?
```python
random._rng_tracker.distribute_region_enabled = False

if device_mesh.ndim > 1:
    _deprecate_warnings("tp_mesh_dim", "If you have a 2-D or N-D device_mesh, consider passing in device_mesh[\"tp\"]")
```
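For anyone who hits this new warning, a hedged sketch of the suggested migration; the mesh shape and the "tp" dimension name are placeholders rather than anything this diff prescribes:

```python
from torch.distributed.device_mesh import init_device_mesh

# Before (deprecated): pass the full 2-D mesh plus a manual mesh-dim index, e.g.
#   parallelize_module(model, mesh_2d, plan, tp_mesh_dim=1)
# After: slice the named TP dimension out of the mesh and pass the 1-D sub-mesh.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # 1-D sub-mesh covering only this rank's TP group
# parallelize_module(model, tp_mesh, plan)
```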
Nice!
We can later add a reference using Less' tutorial on pytorch/examples as well.
Also, we want to remove the PairwiseParallel import from here:
Looks like a good PR to me. I didn't go too deep, but I am happy to see the reduction in APIs.
```python
self.input_layouts = (input_layouts or Replicate(), )
self.output_layouts = (output_layouts or Shard(-1), )
```
Do we still want to keep the type check for input_layouts and output_layouts?
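As an illustration of both the None-default idiom quoted above and one possible answer to this question, here is a hypothetical sketch; the class name and the eager type check are made up for the example and are not the PR's actual implementation:

```python
from typing import Optional

from torch.distributed._tensor.placement_types import Placement, Replicate, Shard


class ColwiseParallelSketch:
    """Hypothetical sketch only; mirrors the None-default idiom, not the real style."""

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,   # None default instead of Replicate()
        output_layouts: Optional[Placement] = None,  # None default instead of Shard(-1)
        use_local_output: bool = True,
    ):
        # Fall back to the documented defaults only when the caller passed nothing.
        self.input_layouts = (input_layouts or Replicate(),)
        self.output_layouts = (output_layouts or Shard(-1),)
        self.use_local_output = use_local_output

        # One possible answer to the question above: validate eagerly so a wrong type
        # fails at construction time rather than deep inside parallelize_module.
        for layout in (*self.input_layouts, *self.output_layouts):
            if not isinstance(layout, Placement):
                raise TypeError(f"expected a Placement, got {type(layout).__name__}")
```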
Thanks for cleaning this up and making TP usability better. Overall this looks good to me; please make sure all tests pass.
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Approved by: https://github.com/wz337, https://github.com/fduwjj
Differential Revision: D51761183