关于对PyTorch中nn.Linear的官方API文档解读

torch.nn.Linear(in_features, out_features, bias=True)

1.1 作用

 

pytorch中的全连接层,对于输入的向量x进行线性变换y=xAT+b

 

1.2 参数

 

in_features – 输入的向量的size

 

out_features – 最终返回的输出向量的size

 

bias – If set to False, the layer will not learn an additive bias. Default: True

 

1.3 形状

 

  • Input: (N,∗,Hin) where * means any number of additional dimensions and Hin=in_features

  • Output: (N,∗,Hout) where all but the last dimension are the same shape as the input and Hout=out_features .

 

意思就是输入输出向量不一定非要是两维,保证第一维相同,最后一个dim能和in_features,out_features匹配就行了

 

from torch import nn
import torch

linear=nn.Linear(in_features=64*3,out_features=1)

a=torch.rand(3,7,64*3)

print(a.shape)
print(linear.weight.shape)
b=linear(a)
print(b.shape)

 

/Users/weihaoyang/opt/anaconda3/envs/nlp_chat/bin/python /Users/weihaoyang/PycharmProjects/pytorch练习/test.py
torch.Size([3, 7, 192])
torch.Size([1, 192])
torch.Size([3, 7, 1])

Process finished with exit code 0

 

可以看出,输入是 3,7,192 输出是 3,7,1 ,其中 [1,192]就是这个全连接层的权重w的size, 经过转置之后变为[192,1] 。 [7,192] x [192,1] => [7,1]




上一篇:192. 统计词频


下一篇:文本溢出显示省略号css