MaxPool2d 的使用
此处我们仍然使用官网自带的数据集进行训练,最后将其可视化
加载数据集和可视化部分在此处不在介绍,若需要了解:
加载数据集:torch.utils.data中的DataLoader数据加载器(附代码)_硕大的蛋的博客-CSDN博客
tensorboard可视化工具:Tensorboard 可视化工具的使用-史上最简单(附代码)_硕大的蛋的博客-CSDN博客
第一步
导入相应的模块和包
import torch.nn as nn from torch.nn import MaxPool2d import torchvision from torch.utils.data import DataLoader from tensorboardX import SummaryWriter
第二步
加载数据
dataset = torchvision.datasets.CIFAR10('../BigData', train=False, transform=torchvision.transforms.ToTensor(), download=True) dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
第三步
创建神经网络
class Gsw(nn.Module): def __init__(self): super(Gsw, self).__init__() self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=False) def forward(self, x): out = self.maxpool1(x) return out
第四步
训练并将其可视化
gsw = Gsw() writer = SummaryWriter('LOGS/012log') for step, data in enumerate(dataloader): imgs, targets = data writer.add_images('input', imgs, step) output = gsw(imgs) writer.add_images('output', output, step)
完整代码
# 开发时间: 2021/11/22 16:26 import torch.nn as nn from torch.nn import MaxPool2d import torchvision from torch.utils.data import DataLoader from tensorboardX import SummaryWriter dataset = torchvision.datasets.CIFAR10('../BigData', train=False, transform=torchvision.transforms.ToTensor(), download=True) dataloader = DataLoader(dataset, batch_size=64, shuffle=True) class Gsw(nn.Module): def __init__(self): super(Gsw, self).__init__() self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=False) def forward(self, x): out = self.maxpool1(x) return out gsw = Gsw() writer = SummaryWriter('LOGS/012log') for step, data in enumerate(dataloader): imgs, targets = data writer.add_images('input', imgs, step) output = gsw(imgs) writer.add_images('output', output, step)