- a
import torch
align = torch.FloatTensor([3,4,8]).to(torch.long)
torch.repeat_interleave(torch.eye(3),align,dim=1)
tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.]])