莫烦Python_关系拟合(回归)

  这次来见证神经网络是如何通过简单的形式将一群数据用一条线条来表示,或者说是如何在数据当中找到它们的关系,然后用神经网络模型来建立一个可以代表它们关系的线条。
  创建一些假数据来模拟真实的情况,比如一个一元二次函数 y = a ∗ x 2 + b y = a * x^2 + b y=a∗x2+b,给y数据加上一点噪声来更加真实地展示它:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# x data (tensor),shape=(100, 1)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
# noisy y data (tensor),shape=(100, 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

  可以直接运用torch中的体系建立一个神经网络。先定义所有的层属性(__init__()),然后再一层层搭建(forward(x))层与层的关系链接。建立关系的时候,我们会用到激励函数:

class Net(torch.nn.Module):  # 继承torch的Module
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()  # 继承“__init__”功能
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 隐藏层线性输出
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 输出层线性输出

    def forward(self, x):  # 这同时也是Module中的forward功能
        # 正向传播输入值,神经网络分析出输出值
        x = F.relu(self.hidden(x))  # 激励函数
        x = self.predict(x)  # 输出值
        return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net)  # 打印net的结构

执行结果:

Net(
  (hidden): Linear(in_features=1, out_features=10, bias=True)
  (predict): Linear(in_features=10, out_features=1, bias=True)
)

  训练的步骤如下:

# 传入net的所有参数以及学习率
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()  # 预测值和真实值的误差计算公式(均方差)

plt.ion()

for t in range(200):
    prediction = net(x)  # 喂给net训练数据x,输出预测值
    loss = loss_func(prediction, y)  # 计算两者的误差

    optimizer.zero_grad()  # 清空上一步的残余更新参数值
    loss.backward()  # 误差反向传播,计算参数更新值
    optimizer.step()  # 将参数更新值施加到net的parameters上

    if t % 5 == 0:
        # plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(),
                 fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)

plt.ioff()
plt.show()

莫烦Python_关系拟合(回归)

上一篇:(自用)java博客作业3 Java抽象类


下一篇:PHP图片加文字水印和图片水印方法(鉴于李老师博客因没加水印被盗,特搜集的办法。希望能有用!)