查看设备GPU数目

import torch

def try_gpu(i=0):
    """如果存在返回gpu(i), 否则返回cpu()"""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

def try_all_gpus():
    """返回该设备上的所有GPU数"""
    devices = [
        torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())]
    return devices if devices else [torch.device('cpu')]

a = try_gpu()
b = try_gpu(1)
c = try_all_gpus()
print(a, b, c)

上一篇:Java中的常见异常及其处理


下一篇:Netty入门教程——认识Netty