- enumerate(sequence, [start=0])
参数:
sequence – 一个序列、迭代器或其他支持迭代对象。
start – 下标起始位置。
返回值:
列出数据和数据下标,一般用在 for 循环当中。
for i, data in enumerate(train_loader):
inputs, labels = data
print(inputs,shape)
print(labels.shape)
break
# print output:
# torch.Size([64, 1, 28, 28])
# torch.Size([64])
- x = x.view(x.size()[0], -1)
x是多维tensor,使用函数view()变成二维tensor,行数=batch size,列数为每个input的维度。