代码实现

import torch.nn.functional as F
import torch
def pixelshuffle_inv(tensor: torch.Tensor, scale: int = 2) -> torch.Tensor:
    """Rearrange spatial blocks into channels (inverse of pixel shuffle).

    Maps a tensor of shape ``(*, C, H, W)`` to
    ``(*, C * scale**2, H // scale, W // scale)``, where ``*`` is any number
    of leading batch dimensions (including none). The channel ordering
    follows ``torch.nn.functional.pixel_unshuffle``.

    Args:
        tensor: Input with at least three dimensions, laid out ``(*, C, H, W)``.
        scale: Downscaling factor; ``H`` and ``W`` must both be multiples of it.

    Returns:
        The rearranged tensor.

    Raises:
        ValueError: If ``tensor`` has fewer than three dimensions.
        RuntimeError: If ``H`` or ``W`` is not divisible by ``scale``.
    """
    if tensor.dim() < 3:
        raise ValueError(
            "expected a tensor with at least 3 dimensions (*, C, H, W), "
            f"got {tensor.dim()}"
        )
    *batch, ch, height, width = tensor.shape

    # Checked explicitly so the failure is readable and so zero-element
    # tensors (for which any view shape "fits") are not silently accepted.
    # RuntimeError is the type the mismatched view below would raise anyway.
    if height % scale != 0 or width % scale != 0:
        raise RuntimeError(
            f"height ({height}) and width ({width}) must both be divisible "
            f"by scale ({scale})"
        )
    new_height = height // scale
    new_width = width // scale

    # Split H -> (H/scale, scale) and W -> (W/scale, scale).
    blocks = tensor.view(*batch, ch, new_height, scale, new_width, scale)

    # Move the two intra-block offsets right after the channel axis:
    # (*, C, H/s, s, W/s, s) -> (*, C, s, s, H/s, W/s).
    k = len(batch)
    blocks = blocks.permute(*range(k), k, k + 2, k + 4, k + 1, k + 3).contiguous()

    # Fold (C, s, s) into a single channel axis.
    return blocks.view(*batch, ch * scale * scale, new_height, new_width)

def pixelshuffle(tensor: torch.Tensor, scale: int = 2) -> torch.Tensor:
    """Rearrange channels into spatial blocks (pixel shuffle).

    Maps a tensor of shape ``(*, C, H, W)`` to
    ``(*, C // scale**2, H * scale, W * scale)``, where ``*`` is any number
    of leading batch dimensions (including none). The element ordering
    follows ``torch.nn.functional.pixel_shuffle``.

    Args:
        tensor: Input with at least three dimensions, laid out ``(*, C, H, W)``.
        scale: Upscaling factor; ``C`` must be a multiple of ``scale**2``.

    Returns:
        The rearranged tensor.

    Raises:
        ValueError: If ``tensor`` has fewer than three dimensions.
        RuntimeError: If ``C`` is not divisible by ``scale**2``.
    """
    if tensor.dim() < 3:
        raise ValueError(
            "expected a tensor with at least 3 dimensions (*, C, H, W), "
            f"got {tensor.dim()}"
        )
    *batch, ch, height, width = tensor.shape

    # Checked explicitly so the failure is readable and so zero-element
    # tensors (for which any view shape "fits") are not silently accepted.
    # RuntimeError is the type the mismatched view below would raise anyway.
    group = scale * scale
    if ch % group != 0:
        raise RuntimeError(
            f"number of channels ({ch}) must be divisible by scale**2 ({group})"
        )
    new_ch = ch // group

    # Rearrange the tensor: split C -> (C/s^2, s, s).
    output_tensor = tensor.view(*batch, new_ch, scale, scale, height, width)

    # Interleave each block offset with its spatial axis:
    # (*, C', s, s, H, W) -> (*, C', H, s, W, s).
    k = len(batch)
    output_tensor = output_tensor.permute(
        *range(k), k, k + 3, k + 1, k + 4, k + 2
    ).contiguous()

    # Fold (H, s) and (W, s) into the enlarged spatial axes.
    return output_tensor.view(*batch, new_ch, height * scale, width * scale)


if __name__ == '__main__':
    # Self-check: compare both hand-written rearrangements against the
    # torch.nn.functional reference implementations on a random image batch.
    # Named `x` rather than `input` to avoid shadowing the builtin.
    x = torch.randn(1, 3, 256, 256)
    scale = 2

    unshuffle_ = pixelshuffle_inv(x, scale)
    unshuffle_F = F.pixel_unshuffle(x, scale)
    print(torch.equal(unshuffle_, unshuffle_F))
    # Report the largest absolute deviation; the max of a signed difference
    # would hide mismatches where the custom result is smaller.
    print(torch.max(torch.abs(unshuffle_ - unshuffle_F)))

    # Feed the reference unshuffle output back through both shuffles.
    shuffle_ = pixelshuffle(unshuffle_F, scale)
    shuffle_F = F.pixel_shuffle(unshuffle_F, scale)
    print(torch.equal(shuffle_, shuffle_F))
    print(torch.max(torch.abs(shuffle_ - shuffle_F)))

运行结果:两个函数的输出与官方实现(F.pixel_unshuffle、F.pixel_shuffle)完全一致。
在这里插入图片描述

上一篇:高防服务器的优劣势有哪些?


下一篇:【py】python实现矩阵的加、减、点乘、乘法