netron是微软小哥lutzroeder的一个广受好评的开源项目,地址https://github.com/lutzroeder/Netro
可惜,默认支持的格式中并不包括pytorch,可能当年小哥面试facebook被拒了,:)
Netron supports ONNX (
.onnx
,.pb
,.pbtxt
), Keras (.h5
,.keras
), Core ML (.mlmodel
), Caffe (.caffemodel
,.prototxt
), Caffe2 (predict_net.pb
,predict_net.pbtxt
), MXNet (.model
,-symbol.json
), NCNN (.param
) and TensorFlow Lite (.tflite
).
1. 安装netron
pip install netron
2. 测试代码
由于不支持默认的pytorch模型格式(.pth),因此需要存为onnx,庆幸pytorch支持!
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
import netron
class model(nn.Module):
def __init__(self):
super(model, self).__init__()
self.block1 = nn.Sequential(
nn.Conv2d(64, 64, 3, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 32, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3, padding=1, bias=False),
nn.BatchNorm2d(64)
)
self.conv1 = nn.Conv2d(3, 64, 3, padding=1, bias=False)
self.output = nn.Sequential(
nn.Conv2d(64, 1, 3, padding=1, bias=True),
nn.Sigmoid()
)
def forward(self, x):
x = self.conv1(x)
identity = x
x = F.relu(self.block1(x) + identity)
x = self.output(x)
return x
d = torch.rand(1, 3, 416, 416)
m = model()
o = m(d)
onnx_path = "onnx_model_name.onnx"
torch.onnx.export(m, d, onnx_path)
netron.start(onnx_path)
3. 结果
执行上面代码后,会调用本地浏览器打开,形式和tensorboard差不多
Serving ‘onnx_model_name.onnx‘ at http://localhost:8080