Pytorch中的tensor又包括CPU上的数据类型和GPU上的数据类型,一般GPU上的Tensor是CPU上的Tensor加cuda()函数得到。
一般系统默认是torch.FloatTensor类型。例如data = torch.Tensor(2,3)是一个2*3的张量,类型为FloatTensor;
data.cuda()就转换为GPU的张量类型,torch.cuda.FloatTensor类型。
if cuda: dtype = torch.cuda.FloatTensor else: if torch.cuda.is_available(): print("WARNING: You have a CUDA device, so you should probably set cuda=True") dtype = torch.FloatTensor