(y_hat.argmax(dim=1) ==lable).sum().cpu().item()

        print(y_hat.argmax(dim=1))         print(y_hat.argmax(dim=1) ==lable)         print((y_hat.argmax(dim=1) ==lable).sum())         print((y_hat.argmax(dim=1) ==lable).sum().cpu())              print((y_hat.argmax(dim=1) ==lable).sum().cpu().item()) 输出:

tensor([4, 4, 5, 0, 4, 2, 8, 5, 8, 2, 4, 4, 4, 2, 8, 4, 1, 8, 2, 5, 0, 7, 4, 4,
6, 5, 6, 2, 5, 3, 5, 4, 4, 8, 4, 5, 2, 4, 2, 4, 6, 2, 5, 6, 5, 8, 4, 4,
4, 2], device='cuda:0')

第一个输出把每行最大的索引输出
tensor([False, False, False, False, False, False, False, False, True, False,
True, False, False, True, False, False, False, False, True, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, False, False, False, False, False,
False, True, False, False, False, False, False, False, False, False],
device='cuda:0')

第二个输出判断索引和lable是否相等,相等为true否则为false。
tensor(6, device='cuda:0')

第三个输出进行sum求和true算1,flase算0。
tensor(6)

第四个输出将cuda变为cpu
6

第五个item将tensor变为整形

上一篇:【Android 应用开发】Android 开发环境下载地址 -- 百度网盘 adt-bundle android-studio sdk adt 下载


下一篇:numpy.argmax函数的相关用法