PyTorch实现神经网络分类功能

PyTorch实现神经网络分类功能

以下代码使用 PyTorch 搭建一个单隐藏层的全连接神经网络，对两簇二维高斯随机点进行二分类，并用 matplotlib 动态显示训练过程中的预测结果与准确率：

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# --- Synthetic two-class dataset -------------------------------------------
# Two Gaussian blobs of 100 points each in 2-D: class 0 is centred at (2, 2),
# class 1 at (-2, -2), both with standard deviation 1.
n_data = torch.ones(100, 2)  # 100x2 template of ones, scaled below to set each blob's mean

x0 = torch.normal(2 * n_data, 1)   # class-0 point coordinates (mean 2, std 1)
y0 = torch.zeros(100)              # class-0 labels
x1 = torch.normal(-2 * n_data, 1)  # class-1 point coordinates (mean -2, std 1)
y1 = torch.ones(100)               # class-1 labels

# Stack both blobs row-wise: x is a 200x2 float32 feature matrix and y holds
# 200 int64 class labels (first 100 are 0, last 100 are 1).
x = torch.cat([x0, x1], dim=0).float()
y = torch.cat([y0, y1], dim=0).long()


class Net(torch.nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        n_feature: number of input features per sample.
        n_hidden: width of the hidden layer.
        n_output: number of output scores (one per class).

    ``forward`` returns the raw output-layer scores; no softmax is applied.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        # The attribute names ``hidden`` and ``out`` determine the parameter
        # names (e.g. "hidden.weight", "out.bias"), so keep them stable.
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.out = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Map inputs with last dimension n_feature to scores with last dimension n_output."""
        return self.out(F.relu(self.hidden(x)))


# Build the classifier: 2 input coordinates -> 10 hidden units -> 2 class scores,
# then print its layer layout.
net = Net(n_feature=2, n_hidden=10, n_output=2)
print(net)

# Classification loss on the network's raw (un-softmaxed) scores and the
# integer labels; plain SGD updates every parameter of the network.
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.02)

plt.ion()  # interactive mode, so the figure redraws while training runs

# x and y never change during training, so convert them to NumPy once
# instead of on every plotted step.
x_np = x.numpy()
target_y = y.numpy()

for t in range(100):
    out = net(x)  # forward pass: one row of class scores per sample in x
    loss = loss_func(out, y)

    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # backpropagation: compute gradients
    optimizer.step()       # apply the SGD update

    if t % 2 == 0:
        # Redraw every other step to show how the predictions evolve.
        plt.cla()  # clear the axes so successive frames don't overlap
        # out has one row per sample in x (200x2 here, not 100x2).
        # torch.max(out, 1) returns (max values, argmax indices) per row;
        # [1] keeps the indices, i.e. the predicted class of each point.
        # These predictions come from `out`, which was computed before this
        # step's optimizer update.
        prediction = torch.max(out, 1)[1]
        pred_y = prediction.numpy()
        # Color each point by its predicted class.
        # s: marker size (points^2); lw=0: no marker edge; cmap: class colormap.
        plt.scatter(x_np[:, 0], x_np[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
        # Fraction of training points whose predicted class matches the label.
        accuracy = float((pred_y == target_y).mean())
        # (1.5, -4) is the data coordinate where the accuracy text is drawn.
        plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)

plt.ioff()
plt.show()

输出结果：

PyTorch实现神经网络分类功能

上一篇:(自用)java博客作业3 Java抽象类


下一篇:Html3