pytorch 计算flops和params

pytorch 计算 params

from thop import profile
import torch
from resnet_18 import Resnet_18, resnet18

model = Resnet_18()
input = torch.randn(1, 3, 256, 256)
flops, params = profile(model, inputs = (input))
print(flops)
print(params)

**FPS 计算过程 **

res = [] 
for id, (data, depth, img_name, img_size) in enumerate(test_loader):
    torch.cuda.synchronize()
    start = time.time()
    predict= model_rgb(inputs, depth)  # 有待修改
    torch.cuda.synchronize()
    end = time.time()
    res.append(end-start)
time_sum = 0
for i in res:
    time_sum += i
print("FPS: %f"%(1.0/(time_sum/len(res))))

上一篇:python - 接口自动化 - Requests 调试脚本综合实战


下一篇:swiper7-33. 添加函数开启监听轮播图的各种事件(点击,滑动...)