In PyTorch, there are two ways to use an activation function:
Method 1:
import torch.nn.functional as F
# ...
out = F.relu(input)
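Here is a minimal runnable sketch of the functional form. It assumes only that PyTorch is installed, and the tensor values are arbitrary illustration data. `F.relu` is applied directly to a tensor, and no layer object is created:

```python
import torch
import torch.nn.functional as F

# Arbitrary example input with negative, zero and positive entries
x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])

# Functional form: call F.relu directly on the tensor; nothing is stored on a module
out = F.relu(x)
print(out)  # negative entries are clamped to 0, positive ones pass through unchanged
```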
Method 2:
import torch.nn as nn
# ...
nn.ReLU()
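The module form can be sketched the same way, using the same arbitrary input. An `nn.ReLU` object is constructed once and then called like a function. The object holds no learnable parameters:

```python
import torch
import torch.nn as nn

x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])

# Module form: construct an nn.ReLU object first, then call it on the tensor
relu = nn.ReLU()
out = relu(x)

# nn.ReLU carries no learnable parameters
params = list(relu.parameters())
print(out, params)
```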
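Both approaches apply the same ReLU activation. They differ only in where they are typically used. F.relu() is a function call and is generally used inside the forward function. nn.ReLU() is a module and is generally used when defining the network's layers.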
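The two usage scenarios can be illustrated with a small sketch. The layer sizes and the class name `TinyNet` are hypothetical. Both forms give identical results on the same tensor. The module form fits into `nn.Sequential`, while the functional form is called inside `forward()`. The sketch also shows that `nn.Sequential` rejects the bare function with a `TypeError`, because it only accepts `nn.Module` instances:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)

# Both forms compute the same element-wise max(0, x)
same = torch.equal(F.relu(x), nn.ReLU()(x))

# Module form: the activation is declared as a layer object, e.g. inside nn.Sequential
seq = nn.Sequential(nn.Linear(8, 4), nn.ReLU())

# Functional form (hypothetical TinyNet): only the Linear layer is declared,
# and ReLU is applied as a function call in forward()
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        return F.relu(self.fc(x))

net = TinyNet()
# Copy the Linear weights so that the two models can be compared directly
net.fc.load_state_dict(seq[0].state_dict())

# nn.Sequential only accepts nn.Module instances, so F.relu cannot be used as a layer there
try:
    nn.Sequential(nn.Linear(8, 4), F.relu)
    rejected = False
except TypeError:
    rejected = True

print(same, torch.allclose(seq(x), net(x)), rejected)
```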
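When the network is printed with print(net), the output lists an nn.ReLU() layer, while F.relu() does not appear at all. The function call is never registered as a submodule of the network.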
import torch.nn as nn
import torch.nn.functional as F

class NET1(nn.Module):
    def __init__(self):
        super(NET1, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)
        self.re = nn.ReLU()  # activation as a module (registered as a layer)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)   # feed the conv output, not the raw input
        out = self.re(out)
        return out

class NET2(nn.Module):
    def __init__(self):
        super(NET2, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = F.relu(out)    # activation as a function (not registered)
        return out

net1 = NET1()
net2 = NET2()
print(net1)
print(net2)
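The next sketch checks the difference without relying on printed output. So that it runs on its own, it redefines a compact copy of the two networks, using the hypothetical names `WithModule` and `WithFunction`, and then inspects their registered children. `nn.ReLU` has no parameters or buffers, so both `state_dict`s have the same keys. This means the weights can be copied from one network to the other and the outputs compared:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical compact copies of NET1 / NET2: same layers, ReLU as a module vs. as a function
class WithModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)
        self.re = nn.ReLU()

    def forward(self, x):
        return self.re(self.bn(self.conv(x)))

class WithFunction(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))

net1 = WithModule()
net2 = WithFunction()

# Only the module version has a registered ReLU child
names1 = [name for name, _ in net1.named_children()]
names2 = [name for name, _ in net2.named_children()]

# Same state_dict keys, so net1's weights can be loaded into net2
net2.load_state_dict(net1.state_dict())
net1.eval()  # use BatchNorm running statistics instead of batch statistics
net2.eval()

x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    same = torch.allclose(net1(x), net2(x))

print(names1, names2, same)
```

The only difference is bookkeeping. `nn.ReLU()` shows up in `print(net)` and `named_children()`, while `F.relu` is invisible there but computes the same result.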