[deepspeed / m2m_100] make deepspeed zero-3 work with layerdrop by stas00 · Pull Request #16717 · huggingface/transformers · GitHub

Conversation

@stas00
Contributor

@stas00 stas00 commented Apr 12, 2022

This is the same fix I had to make in wav2vec2, and it looks like it should eventually go into all models that use LayerDrop. At least at the moment DeepSpeed is not capable of randomly skipping layers, so this PR uses the same, now well-tested workaround I used in wav2vec2: when DeepSpeed ZeRO-3 is detected, all layers always run, but the results of a layer are ignored if it was meant to be skipped.

deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

for layer in self.layers:
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
    dropout_probability = np.random.uniform(0, 1)
    skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False

    if not skip_the_layer or deepspeed_zero3_is_enabled:
        # under deepspeed zero3 all gpus must run in sync
        if self.gradient_checkpointing and self.training:
            # create gradient checkpointing function
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs, output_attentions)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(layer),
                hidden_states,
                attention_mask,
            )
        else:
            layer_outputs = layer(
                hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
            )

        hidden_states = layer_outputs[0]

    if skip_the_layer:
        layer_outputs = (None, None)
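
For other models that use LayerDrop the same pattern should transfer directly. Below is a minimal standalone sketch of the idea (the SimpleEncoder class and the zero3_enabled flag are made up for illustration; the flag stands in for is_deepspeed_zero3_enabled()): every layer always executes when ZeRO-3 is enabled, and the output of a layer that was meant to be dropped is simply discarded.

import torch
import torch.nn as nn


class SimpleEncoder(nn.Module):
    def __init__(self, num_layers=4, dim=8, layerdrop=0.5, zero3_enabled=False):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.layerdrop = layerdrop
        # stands in for is_deepspeed_zero3_enabled() in the real modeling code
        self.zero3_enabled = zero3_enabled

    def forward(self, hidden_states):
        for layer in self.layers:
            skip_the_layer = bool(self.training and torch.rand(()) < self.layerdrop)
            if not skip_the_layer or self.zero3_enabled:
                # under ZeRO-3 every rank must run every layer so that the
                # partitioned parameters are gathered in the same order everywhere
                layer_outputs = layer(hidden_states)
            if skip_the_layer:
                # the layer may still have run (under ZeRO-3), but its output is dropped
                continue
            hidden_states = layer_outputs
        return hidden_states


encoder = SimpleEncoder(zero3_enabled=True).train()
out = encoder(torch.randn(2, 8))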

Perhaps one day DeepSpeed will be able to randomly skip layers; at the moment this solution is not the most efficient one. I have made a request for that feature.

When ZeRO-3 is not used the original code path is taken.
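
For reference, here is a hedged sketch of what flips that switch outside of the Trainer, assuming the transformers.deepspeed API of the time (the ds_config dict is a made-up minimal example; HfDeepSpeedConfig must be kept alive for the detection to work):

from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled

# minimal made-up stage-3 config; a real one would also configure optimizer, fp16, etc.
ds_config = {
    "zero_optimization": {"stage": 3},
    "train_micro_batch_size_per_gpu": 1,
}

# keep a reference alive -- is_deepspeed_zero3_enabled() consults this object
dschf = HfDeepSpeedConfig(ds_config)

assert is_deepspeed_zero3_enabled()

Inside the Trainer this config object is created automatically from the file passed via --deepspeed, which is why the modeling code only needs the is_deepspeed_zero3_enabled() query.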

The test exercising this code path will be merged as part of the huge additional-tests PR #12695 (which is long overdue).

For posterity: the error comes from ZeRO-3's parameter-fetch tracing going out of sync when the set of executed layers changes, and it will look something like this:

RuntimeError: tracing error at step 42: expected the next 2 parameters in the parameter fetch queue to be 
({'id': 26, 'status': 'AVAILABLE', 'numel': 1024, 'ds_numel': 1024, 'shape': (1024,), 'ds_shape': (1024,), 'requires_grad': True, 'grad_shape': None, 'persist': True, 'active_sub_modules': {24}}, {'id': 27, 'status': 'AVAILABLE', 'numel': 1024, 'ds_numel': 1024, 'shape': (1024,), 'ds_shape': (1024,), 'requires_grad': True, 'grad_shape': None, 'persist': True, 'active_sub_modules': {24}}) 
but got 
({'id': 115, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 1024, 'shape': (0,), 'ds_shape': (1024,), 'requires_grad': True, 'grad_shape': None, 'persist': True, 'active_sub_modules': set()}, {'id': 116, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 1048576, 'shape': (0,), 'ds_shape': (1024, 1024), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set()}).

Fixes: #16688

@patil-suraj, @sgugger

@stas00 stas00 self-assigned this Apr 12, 2022
@stas00 stas00 changed the title [deepspeed / m2m_100] make deepspeed 3 work with layerdrop [deepspeed / m2m_100] make deepspeed zero-3 work with layerdrop Apr 12, 2022
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 12, 2022

The documentation is not available anymore as the PR was closed or merged.

Contributor

@patil-suraj patil-suraj left a comment


Thanks a lot for fixing this!

Collaborator

@sgugger sgugger left a comment


Thanks for working on this!

@stas00 stas00 merged commit c21e107 into main Apr 14, 2022
@stas00 stas00 deleted the ds-m2m-layerdrop branch April 14, 2022 13:51
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…ingface#16717)

* [deepspeed / m2m_100] make deepspeed 3 work with layerdrop

* fix

* revert last


Development

Successfully merging this pull request may close these issues.

Cannot train M2M100 using run_translation.py and DeepSpeed ZeRO stage 3
