使用netron工具可视化Pytorch模型
1 安装netron
pip install netron
2 导入包
import netron
import torch.onnx
程序调用
if __name__ == '__main__':
net = vgg()
x = Variable(torch.FloatTensor(16, 3, 40, 40))
y = net(x)
print(y.data.shape)
onnx_path = "onnx_model_name.onnx"
torch.onnx.export(net, x, onnx_path)
netron.start(onnx_path)