Debug大神代码的过程中,有好多torch语句不熟悉,阅读速度很慢,转眼就忘,顺便在这里记录下。
1、torch.ge
其实就是比较,官网上也有,torch.ge — PyTorch 1.10.0 documentation
一种是比较两个tensor之间,维度要一致,逐个元素进行比较,相同的返回True,否则返回False。
import torch
a = torch.arange(0,6).view(2,3)
b = torch.arange(2,8).view(2,3)
torch.ge(a,b)
第二种是和常数比较大小,大于等于该数返回1,否则返回0。
2、torch.eq
与torch.ge用法类似,也是两种,但要注意区别,torch.eq — PyTorch 1.10.0 documentation
第一种是两个tensor比较,元素相同返回True,否则返回False。
第二种是和常数比较,与常数相同返回True,否则返回False。
import torch
a = torch.arange(0,6).view(2,3)
b = torch.arange(2,8).view(2,3)
torch.eq(a,b) #第一种,两个tensor比较
torch.eq(a,1) #第二种,和常数比较
3、torch.where
返回True的位置,得到的是tuple类型。
4、torch.randperm
官方文档 :torch.randperm — PyTorch 1.10.0 documentation
torch.randperm(n),表示从0到n,打乱顺序随机排列。