
nn.Linear weight initialization - uniform or kaiming_uniform? #57109

@adrianstaniec

Description


IMHO there is a discrepancy between the docs and the code of nn.Linear when it comes to weight initialization.

The documentation says that the weights are initialized from
uniform(-1/sqrt(in_features), 1/sqrt(in_features)):

weight: the learnable weights of the module of shape
:math:`(\text{out\_features}, \text{in\_features})`. The values are
initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
:math:`k = \frac{1}{\text{in\_features}}`
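
To make the comparison concrete, here is a quick empirical check of the documented bound (a minimal sketch; the layer sizes 64 and 32 are arbitrary):

import math
import torch.nn as nn

lin = nn.Linear(64, 32)
doc_bound = math.sqrt(1.0 / 64)  # sqrt(k) with k = 1/in_features
# all weights fall inside the documented interval
print(lin.weight.abs().max().item() <= doc_bound)  # True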

The code, however, initializes the weights with kaiming_uniform_:

def reset_parameters(self) -> None:
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))

and that includes a factor of sqrt(3), a gain based on `a`, and the `fan`:

pytorch/torch/nn/init.py

Lines 390 to 395 in 77721ee

fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
with torch.no_grad():
    return tensor.uniform_(-bound, bound)
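
For what it's worth, plugging a = sqrt(5) into those lines (with the defaults mode='fan_in' and nonlinearity='leaky_relu', so fan = in_features and gain = sqrt(2 / (1 + a^2))) seems to collapse to the same bound the docs state; this is a sketch of the arithmetic only, not a claim about the intent:

import math

fan_in = 64                           # in_features of an example layer
a = math.sqrt(5)
gain = math.sqrt(2.0 / (1 + a ** 2))  # calculate_gain('leaky_relu', a) = 1/sqrt(3)
bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)  # = 1/sqrt(fan_in)
doc_bound = math.sqrt(1.0 / fan_in)   # docs: sqrt(k), k = 1/in_features
print(abs(bound - doc_bound) < 1e-12)  # True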

Is that an error, or am I missing something?

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @brianjo @mruberry @albanD

Metadata

    Labels

    high priority
    module: docs (Related to our documentation, both in docs/ and docblocks)
    module: nn (Related to torch.nn)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
