Remove dtype check on meta device #136774
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136774
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit c2a37a2 with merge base 69bcf10. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D54175438
This PR needs a
4e1ded8 to a652b5e (force push)
This pull request was exported from Phabricator. Differential Revision: D54175438
a652b5e to c2a37a2 (force push)
This pull request was exported from Phabricator. Differential Revision: D54175438
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary:
Latest Update
This diff is no longer needed: the check does need to exist so that the meta device behaves the same as other devices; see D54526190.
Background
T176105639
| case | embedding bag weight | per_sample_weight | fbgemm lookup | forward in meta |
| --- | --- | --- | --- | --- |
| A | fp32 | fp32 | good | good |
| B | fp16 | fp32 | good | failed check that forces weight dtype == per_sample_weights dtype |
| C | fp16 | fp16 | P1046999270, RuntimeError: "expected scalar type Float but found Half" from the fbgemm call | good |
| D | fp32 | fp16 | N/A | N/A |
Currently we are in case A. Users need to add `use_fp32_embedding` in training to force the embedding bag dtype to be fp32. However, users actually want case B, with fp16 as the embedding bag weight. When they delete `use_fp32_embedding`, they fail the check (https://fburl.com/code/k3n3h031) in meta_registration that forces `weight dtype == per_sample_weights dtype`.
The check is actually not necessary, because the backend fbgemm does support case B. Additionally, later on in `meta_embedding_bag`, `weight` and `per_sample_weights` don't need to be in the same dtype for `is_fast_path_index_select` (https://fburl.com/code/q0tho05h: weight is src, per_sample_weights is scale).
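For context, a rough sketch of the kind of dtype guard being discussed (the helper name is hypothetical and this is not the verbatim PyTorch source):

```python
import torch

# Rough sketch of the dtype guard in the meta registration (helper name is
# hypothetical, not the verbatim source). Removing this is what lets case B
# (fp16 weight, fp32 per_sample_weights) through the meta forward.
def _check_per_sample_weights_dtype(weight, per_sample_weights):
    if per_sample_weights is not None:
        torch._check(
            weight.dtype == per_sample_weights.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights "
            f"({per_sample_weights.dtype}) to have the same dtype",
        )
```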
This diff
Therefore, this diff removes the unnecessary check so that case B works in the meta forward. With this change, users can use fp16 as the embedding bag dtype without forcing per_sample_weights to the same dtype in the meta forward (see Test Plan).
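For illustration, a minimal sketch of what case B looks like at the call site (hypothetical shapes and values; `torch.nn.functional.embedding_bag` dispatches to the meta registration when the tensors live on the meta device). With the dtype check in place this combination is rejected on meta, which is the case-B failure in the table; with this diff it returns a meta tensor:

```python
import torch
import torch.nn.functional as F

# Case B: fp16 embedding bag weight, fp32 per_sample_weights (hypothetical shapes).
weight = torch.randn(10, 4, dtype=torch.float16, device="meta")
indices = torch.zeros(6, dtype=torch.int64, device="meta")
offsets = torch.tensor([0, 3], dtype=torch.int64, device="meta")
per_sample_weights = torch.ones(6, dtype=torch.float32, device="meta")

# per_sample_weights is only supported with mode="sum".
out = F.embedding_bag(
    indices, weight, offsets, mode="sum", per_sample_weights=per_sample_weights
)
print(out.shape)  # (2, 4) meta tensor once the dtype check is removed
```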
Reference diffs to resolve this issue
Diff 1: D52591217
This passes the embedding bag dtype to feature_processor so that per_sample_weights gets the same dtype as the embedding bag weight. However, `is_meta` also needs to be passed because of case C: fbgemm still does not support per_sample_weights = fp16 (see the table above), so per_sample_weights can only be made fp16 when it is on the meta device. The solution requires too many hacks.
Diff 2: D53232739
This does basically the same thing as diff 1 (D52591217), except that the hack is added in the TorchRec library: it adds a branch in EBC and PEA that forces per_sample_weights to fp16 whenever the embedding bag weight is fp16. However, this runs into the same fbgemm issue and has broken a number of prod models.
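For reference, the workaround in those two diffs amounts to a cast along these lines (a hedged sketch with a hypothetical helper name, not the actual feature_processor or TorchRec code):

```python
from typing import Optional

import torch

def _align_per_sample_weights(
    weight: torch.Tensor, per_sample_weights: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    # Force per_sample_weights to match an fp16 embedding bag weight. Diff 1 gates
    # this on the meta device (weight.is_meta) because fbgemm rejects fp16
    # per_sample_weights on real devices (case C in the table above); Diff 2 applies
    # it unconditionally inside TorchRec, which is what broke prod models.
    if per_sample_weights is not None and weight.dtype == torch.float16:
        if weight.is_meta:
            per_sample_weights = per_sample_weights.to(weight.dtype)
    return per_sample_weights
```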
Test Plan:
APS
The following command runs icvr_launcher, which triggers ads_launcher and runs the forward pass on the meta device:
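```
buck2 run mode/opt -c python.package_style=inplace //aps_models/ads/icvr:icvr_launcher_publish -- mode=mast_ig_fm_when_combo0_uhm_publish launcher.fbl_entitlement=ads_global_tc_ads_score launcher.data_project=oncall_ads_model_platform launcher.tags=[ads_ranking_taxonomy_exlarge_fm_prod] stages.train=false
```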
Result:
{F1461463993}
Reviewed By: ezyang
Differential Revision: D54175438