python取数据

M1,M2,M3是形状相同(任意维度和形状都可以,记为shape)的index matrix。
M1的元素是0,1;M2取值于[0:8],M3取值于[0:768]
在data的形状为torch.Size([2, 8, 768])的数据中,取出data[M1,M2,M3]。其形状就是shape,其实就是M1,M2,M3对应,可以看成M1的每个位置填M1,M2,M3的三个索引值分别去三个维度取索引data的那个数据就行了。

上一篇:【无标题】


下一篇:ViT全流程笔记,附代码详解。