KEMBAR78
Improve NEST GPU Utilization 2/N by MahmoudAshraf97 · Pull Request #14089 · NVIDIA-NeMo/NeMo · GitHub
Skip to content

Conversation

@MahmoudAshraf97
Copy link
Contributor

@MahmoudAshraf97 MahmoudAshraf97 commented Jul 1, 2025

This PR is among a series aimed to improve the training speed and GPU utilization of NEST models #13619

after profiling, the masking module takes a good chunk of the total training loop time while involving no computation, thus reducing the utilization, this PR reduces python for-loops as much as possible, and converts the masking operation to a singe function call

As shown in the following graph, This PR along with #14086 reduces the training step time by almost 17%
{4751020E-438F-4DA7-AE1F-E6F14653A76A}

The green is with the PRs applied

cc @stevehuang52

@github-actions github-actions bot added the ASR label Jul 1, 2025
@stevehuang52 stevehuang52 self-requested a review July 1, 2025 13:54
mask_value = self.mask_embedding.unsqueeze(-1)
masks = torch.zeros_like(input_feats)
maksed_feats = input_feats.clone()
masked_feats = input_feats
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this also change input_feats as well when masked_feats is updated? Ideally we should return the masked_feats while keeping input_feats unchanged.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did confirm that by:

module = RandomBlockMasking(80, allow_overlap=False, mask_prob=0.01, block_size=40).cuda()
for i in range(100):
    input = torch.randn(1, 80, 1000).cuda()
    input_len = torch.tensor([1000]).cuda()
    masked_feats, masks = module(input, input_len)
    assert not torch.allclose(masked_feats, input)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks

@stevehuang52
Copy link
Collaborator

@MahmoudAshraf97 could you please fix the DCO error?

Signed-off-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
@MahmoudAshraf97
Copy link
Contributor Author

@stevehuang52 Fixed

@MahmoudAshraf97
Copy link
Contributor Author

@stevehuang52 Can you rerun CICD?

Copy link
Collaborator

@stevehuang52 stevehuang52 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks for the improvement!

@stevehuang52 stevehuang52 merged commit b9371da into NVIDIA-NeMo:main Jul 7, 2025
130 checks passed
@MahmoudAshraf97 MahmoudAshraf97 deleted the nest_masking branch July 7, 2025 16:01
AmirHussein96 pushed a commit to AmirHussein96/NeMo that referenced this pull request Jul 23, 2025
Signed-off-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
Signed-off-by: Amir Hussein <amhussein@nvidia.com>
AmirHussein96 pushed a commit to AmirHussein96/NeMo that referenced this pull request Aug 5, 2025
Signed-off-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
Signed-off-by: Amir Hussein <amhussein@nvidia.com>
AmirHussein96 pushed a commit to AmirHussein96/NeMo that referenced this pull request Aug 5, 2025
Signed-off-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
Signed-off-by: Amir Hussein <amhussein@nvidia.com>
nasretdinovr pushed a commit to nasretdinovr/NeMo that referenced this pull request Aug 8, 2025
Signed-off-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
guyueh1 pushed a commit to guyueh1/NeMo that referenced this pull request Aug 25, 2025
Signed-off-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants