import torch.nn.functional as F
import torch
def pixelshuffle_inv(tensor, scale=2):
    """Fold ``scale x scale`` spatial blocks into the channel axis.

    This is the inverse of pixel shuffle and follows the same channel
    ordering as ``torch.nn.functional.pixel_unshuffle``.

    Args:
        tensor: Tensor of shape ``(*, C, H, W)`` where ``*`` is zero or more
            leading batch dimensions. ``H`` and ``W`` must be multiples of
            ``scale``.
        scale: Positive integer downscale factor.

    Returns:
        Tensor of shape ``(*, C * scale**2, H // scale, W // scale)``.

    Raises:
        ValueError: If ``scale`` is smaller than 1 or ``tensor`` has fewer
            than three dimensions.
        RuntimeError: If ``H`` or ``W`` is not divisible by ``scale`` (same
            exception type the underlying ``view`` call would raise).
    """
    if scale < 1:
        raise ValueError(f"scale must be a positive integer, got {scale}")
    if tensor.dim() < 3:
        raise ValueError(
            f"expected a tensor with at least 3 dimensions (*, C, H, W), "
            f"got shape {tuple(tensor.shape)}"
        )

    *lead, ch, height, width = tensor.shape
    if height % scale or width % scale:
        raise RuntimeError(
            f"height ({height}) and width ({width}) must both be divisible "
            f"by scale ({scale})"
        )

    new_height = height // scale
    new_width = width // scale
    n = len(lead)

    # Split each spatial axis into (coarse position, offset inside the block).
    blocks = tensor.view(*lead, ch, new_height, scale, new_width, scale)
    # Move both in-block offsets right after the channel axis:
    # (*, C, H', r, W', r) -> (*, C, r, r, H', W').
    blocks = blocks.permute(*range(n), n, n + 2, n + 4, n + 1, n + 3).contiguous()
    # Merge (C, r, r) into a single channel axis.
    return blocks.view(*lead, ch * scale * scale, new_height, new_width)
def pixelshuffle(tensor, scale=2):
    """Unfold groups of ``scale**2`` channels into ``scale x scale`` blocks.

    Follows the same channel ordering as
    ``torch.nn.functional.pixel_shuffle``.

    Args:
        tensor: Tensor of shape ``(*, C, H, W)`` where ``*`` is zero or more
            leading batch dimensions. ``C`` must be a multiple of
            ``scale**2``.
        scale: Positive integer upscale factor.

    Returns:
        Tensor of shape ``(*, C // scale**2, H * scale, W * scale)``.

    Raises:
        ValueError: If ``scale`` is smaller than 1 or ``tensor`` has fewer
            than three dimensions.
        RuntimeError: If ``C`` is not divisible by ``scale**2`` (same
            exception type the underlying ``view`` call would raise).
    """
    if scale < 1:
        raise ValueError(f"scale must be a positive integer, got {scale}")
    if tensor.dim() < 3:
        raise ValueError(
            f"expected a tensor with at least 3 dimensions (*, C, H, W), "
            f"got shape {tuple(tensor.shape)}"
        )

    *lead, ch, height, width = tensor.shape
    if ch % (scale * scale):
        raise RuntimeError(
            f"channel count ({ch}) must be divisible by scale**2 "
            f"({scale * scale})"
        )

    new_ch = ch // (scale * scale)
    n = len(lead)

    # Rearrange the tensor: split the channel axis into (C', r, r).
    output_tensor = tensor.view(*lead, new_ch, scale, scale, height, width)
    # Interleave the two block offsets with the spatial axes:
    # (*, C', r, r, H, W) -> (*, C', H, r, W, r).
    output_tensor = output_tensor.permute(
        *range(n), n, n + 3, n + 1, n + 4, n + 2
    ).contiguous()
    # Merge (H, r) and (W, r) into the enlarged spatial axes.
    return output_tensor.view(*lead, new_ch, height * scale, width * scale)
if __name__ == '__main__':
    # Self-check: compare both hand-written functions against PyTorch's
    # built-in pixel_unshuffle / pixel_shuffle on a random image batch.
    x = torch.randn(1, 3, 256, 256)  # named `x` to avoid shadowing builtin input()
    scale = 2

    unshuffle_ = pixelshuffle_inv(x, scale)
    unshuffle_F = F.pixel_unshuffle(x, scale)
    print(torch.equal(unshuffle_, unshuffle_F))
    # Report the largest absolute deviation; a signed max would print 0
    # whenever every mismatch happens to be negative.
    print(torch.max(torch.abs(unshuffle_ - unshuffle_F)))

    shuffle_ = pixelshuffle(unshuffle_F, scale)
    shuffle_F = F.pixel_shuffle(unshuffle_F, scale)
    print(torch.equal(shuffle_, shuffle_F))
    print(torch.max(torch.abs(shuffle_ - shuffle_F)))
# Run result: the outputs are exactly identical to those of the official PyTorch implementation.