Closed
Labels: high priority, module: docs (Related to our documentation, both in docs/ and docblocks), module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
IMHO there is a discrepancy between the docs and the code of nn.Linear when it comes to initialization.
The documentation says that the weights are initialized from
uniform(-1/sqrt(in_features), 1/sqrt(in_features)):
pytorch/torch/nn/modules/linear.py
Lines 53 to 56 in 0df5740
weight: the learnable weights of the module of shape
:math:`(\text{out\_features}, \text{in\_features})`. The values are
initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
:math:`k = \frac{1}{\text{in\_features}}`
The code, however, says that the weights are initialized with
kaiming_uniform_:
pytorch/torch/nn/modules/linear.py
Lines 88 to 89 in 77721ee
def reset_parameters(self) -> None:
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
and that includes a factor of sqrt(3), a gain based on `a`, and the `fan`:
Lines 390 to 395 in 77721ee
fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
with torch.no_grad():
    return tensor.uniform_(-bound, bound)
Is that an error or am I missing something?
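For what it's worth, when I work the quoted formulas through by hand for a = sqrt(5), the two bounds seem to coincide: calculate_gain('leaky_relu', a) is sqrt(2 / (1 + a^2)) = sqrt(2/6) = sqrt(1/3), so bound = sqrt(3) * sqrt(1/3) / sqrt(fan_in) = 1/sqrt(fan_in), which is exactly the documented sqrt(k) with k = 1/in_features. A quick stdlib-only sketch of that arithmetic (fan_in = 128 is just an example value, and the gain formula is reproduced from the leaky_relu case of calculate_gain rather than called from torch):

```python
import math

def kaiming_bound(fan_in, a=math.sqrt(5)):
    # Reproduces the quoted kaiming_uniform_ code path:
    # gain for leaky_relu with negative slope a, as in calculate_gain
    gain = math.sqrt(2.0 / (1.0 + a ** 2))
    std = gain / math.sqrt(fan_in)
    return math.sqrt(3.0) * std

fan_in = 128  # example in_features
doc_bound = 1.0 / math.sqrt(fan_in)  # docstring: sqrt(k), k = 1/in_features
code_bound = kaiming_bound(fan_in)
print(doc_bound, code_bound)  # the two bounds agree
```

So the numbers match; the docstring just states the simplified result of the kaiming_uniform_ call rather than the formula used to compute it.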
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @brianjo @mruberry @albanD