-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Add rnn args check #3925
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rnn args check #3925
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, but please check everything for LSTM.
torch/nn/modules/rnn.py
Outdated
| mini_batch, self.hidden_size) | ||
| if self.mode == 'LSTM': | ||
| hidden = hidden[0] | ||
| if tuple(hidden.size()) != expected_hidden_size: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't this PR have a test for these cases as well? It would be nice to add it back before merging.
* Add rnn args check * Check both hidden sizes for LSTM * RNN args check test
|
@zou3519 I have similar problem, RuntimeError: Expected hidden[0] size (1, 64, 256), got (64, 256), i tried different ways but unable to get it. RuntimeError: Expected hidden[0] size (1, 64, 256), got (64, 256), i tried different ways but unable to get it. |
|
@Shandilya21 Please ask a question on the forums or open a (new) issue if you think there is a bug. |
Fixes #3851, #3259
Added a high-level check for arguments to RNNBase (these were moved from arg checks in
cudnn/rnn.py)Test Plan
python test/test_nn.py