代码原始出处：
https://www.mindspore.cn/tutorial/zh-CN/r1.2/model.html
建立神经网络（LeNet5 网络定义及各层用法示例）：
import mindspore.nn as nn


class LeNet5(nn.Cell):
    """LeNet network structure.

    Two 5x5 'valid' convolutions, each followed by ReLU and 2x2/stride-2
    max pooling, then three fully connected layers (400 -> 120 -> 84 ->
    num_class) with ReLU between them.

    Args:
        num_class: Number of output units of the last Dense layer.
        num_channel: Number of input channels expected by ``conv1``.

    Note:
        ``fc1`` is sized for 16 * 5 * 5 = 400 flattened features, which is
        what the conv/pool stack produces for a 32x32 spatial input
        (32 -> 28 -> 14 -> 10 -> 5). Inputs of a different spatial size
        will generally not match ``fc1``.
    """

    def __init__(self, num_class=10, num_channel=1):
        super().__init__()
        # Define the operations the network needs.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Build the forward network from the operations defined above.
        # Shape notes assume an (N, num_channel, 32, 32) input -- see class Note.
        x = self.conv1(x)       # -> (N, 6, 28, 28)
        x = self.relu(x)
        x = self.max_pool2d(x)  # -> (N, 6, 14, 14)
        x = self.conv2(x)       # -> (N, 16, 10, 10)
        x = self.relu(x)
        x = self.max_pool2d(x)  # -> (N, 16, 5, 5)
        x = self.flatten(x)     # -> (N, 400)
        x = self.fc1(x)         # -> (N, 120)
        x = self.relu(x)
        x = self.fc2(x)         # -> (N, 84)
        x = self.relu(x)
        x = self.fc3(x)         # -> (N, num_class)
        return x


model = LeNet5()
# Print each (name, parameter) item yielded by parameters_and_names().
for m in model.parameters_and_names():
    print(m)
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# 2-D convolution: 1 input channel -> 6 output channels, 5x5 kernel,
# no bias, 'valid' padding (no padding added).
conv2d = nn.Conv2d(1, 6, 5, has_bias=False, weight_init='normal', pad_mode='valid')
input_x = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
# A 5x5 kernel with 'valid' padding shrinks 32x32 to 28x28,
# so the expected shape is (1, 6, 28, 28).
print(conv2d(input_x).shape)
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# ReLU applies max(0, x) element-wise.
relu = nn.ReLU()
# The integer NumPy array is converted to a float16 Tensor here.
input_x = Tensor(np.array([-1, 2, -3, 2, -1]), mindspore.float16)
output = relu(input_x)
# Negative entries become 0: expected values are [0, 2, 0, 2, 0].
print(output)
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# 2x2 max pooling with stride 2.
max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
input_x = Tensor(np.ones([1, 6, 28, 28]), mindspore.float32)
# A 2x2 window with stride 2 halves both spatial dims: expected (1, 6, 14, 14).
print(max_pool2d(input_x).shape)
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# Flatten keeps the first (batch) dimension and merges the rest.
flatten = nn.Flatten()
input_x = Tensor(np.ones([1, 16, 5, 5]), mindspore.float32)
output = flatten(input_x)
# 16 * 5 * 5 = 400 features per sample: expected shape (1, 400).
print(output.shape)
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# Fully connected layer mapping 400 input features to 120 outputs.
dense = nn.Dense(400, 120, weight_init='normal')
input_x = Tensor(np.ones([1, 400]), mindspore.float32)
output = dense(input_x)
# Expected shape (1, 120).
print(output.shape)