pytorch中numel函数用于获取张量中元素数目

pytorch中,numel() 函数用于获取张量中元素数目,其中 numel() 可以理解为是 number of elements 的缩写。

例如:

import torch

a = torch.randn(2,3)
b = a.numel()
print(a,b)

# tensor([[-0.4062, -0.8251, -2.2294],
#         [ 0.5109, -1.4237,  0.8322]]) 6

比如实际应用,numel()函数可用于获取模型参数的总数目:

import logging

# model = ...

num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

上一篇:Golang 创建 Excel 文件


下一篇:7月18日学习打卡,数据结构堆