GFPGAN Source Code Analysis: Part 7

2021SC@SDUSC

Source file: archs\gfpganv1_clean_arch.py

This post analyzes the forward() method of class GFPGANv1Clean(nn.Module) in gfpganv1_clean_arch.py.

Contents

forward()

(1) Encoder

(2) Style code

(3) Decoding

(4) Saving/loading conditions

(5) StyleGAN decoder


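forward()

Parameters:

def forward(self,
            x,                      # input (degraded) face image tensor
            return_latents=False,   # forwarded to stylegan_decoder
            save_feat_path=None,    # if set, save the SFT conditions here (step 4)
            load_feat_path=None,    # if set, load SFT conditions from here (step 4)
            return_rgb=True,        # also produce intermediate RGB images via toRGB
            randomize_noise=True):  # forwarded to stylegan_decoder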

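(1) Encoder

The input first passes through conv_body_first and then through log_size - 2 downsampling blocks (conv_body_down). Each block halves the spatial resolution. Every intermediate feature is prepended to unet_skips, so the list is ordered from the deepest (smallest) feature to the shallowest. This is the order in which the decoding loop in step (3) consumes them. final_conv then produces the bottleneck feature.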

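conditions = []  # scale/shift tensors for the SFT layers
unet_skips = []  # encoder features reused as U-Net skip connections
out_rgbs = []    # intermediate RGB outputs, one per decoder scale

feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
for i in range(self.log_size - 2):
    feat = self.conv_body_down[i](feat)
    unet_skips.insert(0, feat)  # prepend: the deepest feature ends up first
feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)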
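To see what ends up in unet_skips, the loop can be traced with plain integers instead of tensors. This is a minimal sketch, not code from the repository. It assumes out_size = 512, log_size = log2(out_size), that conv_body_first keeps the resolution, and that each conv_body_down block halves it. encoder_skip_sizes is a hypothetical helper.

```python
import math


def encoder_skip_sizes(out_size=512):
    """Hypothetical helper: spatial size (H = W) of every feature in unet_skips."""
    log_size = int(math.log2(out_size))  # assumed to match self.log_size
    size = out_size                      # assumption: conv_body_first keeps H and W
    unet_skips = []
    for _ in range(log_size - 2):
        size //= 2                       # assumption: each down block halves H and W
        unet_skips.insert(0, size)       # same prepend as in forward()
    return unet_skips


print(encoder_skip_sizes(512))  # [4, 8, 16, 32, 64, 128, 256]
```

Because of insert(0, ...), the deepest (4x4) feature sits at index 0.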

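(2) Style code

The bottleneck feature is flattened and projected by final_linear into the style code. When different_w is True, the code is reshaped to (batch, -1, num_style_feat). This gives one num_style_feat-dimensional vector per StyleGAN layer instead of a single shared vector.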

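# flatten the bottleneck feature and project it to the style code
style_code = self.final_linear(feat.view(feat.size(0), -1))
if self.different_w:
    # regroup into one num_style_feat-dim vector per StyleGAN layer
    style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)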
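The reshape only regroups the flat output of final_linear and does not change any values. Below is a small numpy sketch. It assumes out_size = 512 and num_latent = log_size * 2 - 2 (one vector per StyleGAN2 style input, here 16). The random array merely stands in for the final_linear output.

```python
import numpy as np

batch, num_style_feat = 2, 512
log_size = 9                       # assumption: out_size = 512
num_latent = log_size * 2 - 2      # assumption: number of StyleGAN2 style inputs (16)

# stand-in for self.final_linear(...) when different_w is True
flat = np.random.randn(batch, num_latent * num_style_feat).astype(np.float32)

# same regrouping as style_code.view(style_code.size(0), -1, self.num_style_feat)
style_code = flat.reshape(flat.shape[0], -1, num_style_feat)
print(style_code.shape)  # (2, 16, 512)

# layer k's vector is the k-th contiguous slice of the flat code
k = 3
assert np.array_equal(style_code[0, k], flat[0, k * num_style_feat:(k + 1) * num_style_feat])
```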

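(3) Decoding

The decoding loop mirrors the encoder and runs for log_size - 2 steps. At each step it adds the matching skip feature and upsamples with conv_body_up. It then derives a scale map and a shift map (condition_scale, condition_shift) for the SFT layers of the StyleGAN decoder, and both are appended to conditions. If return_rgb is True, toRGB also produces an intermediate RGB image at that scale.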

for i in range(self.log_size - 2):
    # add unet skip
    feat = feat + unet_skips[i]
    # ResUpLayer
    feat = self.conv_body_up[i](feat)
    # generate scale and shift for SFT layer
    scale = self.condition_scale[i](feat)
    conditions.append(scale.clone())
    shift = self.condition_shift[i](feat)
    conditions.append(shift.clone())
    # generate rgb images
    if return_rgb:
        out_rgbs.append(self.toRGB[i](feat))
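The addition feat + unet_skips[i] only works if each skip feature has the same size as feat at that step. Below is a sketch with plain integers. It assumes out_size = 512, that down blocks halve and up blocks (conv_body_up) double H and W, and that final_conv keeps the 4x4 size. decoder_trace is a hypothetical helper.

```python
import math


def decoder_trace(out_size=512):
    """Hypothetical helper: (skip size, size after conv_body_up[i]) for each step."""
    log_size = int(math.log2(out_size))
    # encoder skips, deepest first: 4, 8, ..., out_size // 2 (assumed halving)
    unet_skips = [out_size >> k for k in range(log_size - 2, 0, -1)]
    size = unet_skips[0]                # assumption: final_conv keeps the 4x4 size
    trace = []
    for i in range(log_size - 2):
        assert size == unet_skips[i]    # feat + unet_skips[i] needs equal shapes
        size *= 2                       # assumption: each up block doubles H and W
        trace.append((unet_skips[i], size))
    return trace


steps = decoder_trace(512)
num_conditions = 2 * len(steps)         # one scale and one shift per step
print(steps[0], steps[-1], num_conditions)  # (4, 8) (256, 512) 14
```

Under these assumptions, a 512x512 output yields 14 condition tensors covering resolutions 8 through 512.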

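(4) Saving/loading conditions

save_feat_path and load_feat_path both default to None, so this block is skipped and not used here. When save_feat_path is set, the SFT conditions are written to disk with torch.save. When load_feat_path is set, the conditions are replaced with previously saved ones, loaded with torch.load and moved to the GPU with .cuda().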

if save_feat_path is not None:
    torch.save(conditions, save_feat_path)
if load_feat_path is not None:
    conditions = torch.load(load_feat_path)
    conditions = [v.cuda() for v in conditions]
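Note that .cuda() ties the loading branch to a GPU. Below is a sketch of a device-agnostic variant. It is not the repository's code, only an illustration of torch.save and torch.load with the map_location argument. Random tensors stand in for conditions, and the file name is hypothetical.

```python
import os
import tempfile

import torch

# stand-in for the [scale_0, shift_0, scale_1, shift_1] list built in step (3)
conditions = [torch.randn(1, 4, s, s) for s in (8, 8, 16, 16)]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "conditions.pth")  # hypothetical file name
    torch.save(conditions, path)                # what save_feat_path does
    # map_location puts every tensor on `device` while loading,
    # so this also works on machines without CUDA
    loaded = torch.load(path, map_location=device)

print(len(loaded), loaded[0].shape, loaded[0].device.type)
```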

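(5) StyleGAN decoder

The style code (wrapped in a list) and the SFT conditions are passed to stylegan_decoder, which generates the restored image. Its second return value (the latents, relevant when return_latents is True) is discarded here.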

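image, _ = self.stylegan_decoder([style_code],
                                 conditions,
                                 return_latents=return_latents,
                                 input_is_latent=self.input_is_latent,
                                 randomize_noise=randomize_noise)

# return the restored image and, if requested, the intermediate RGB outputs
if return_rgb:
    return image, out_rgbs
else:
    return image, None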
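Inside stylegan_decoder, the scale/shift pairs in conditions modulate the generator's features at the matching resolutions. Below is a minimal numpy sketch of such a spatial feature transform (SFT). It assumes the plain affine form feat * scale + shift. The actual decoder may differ, for example by modulating only part of the channels, and all shapes here are hypothetical.

```python
import numpy as np


def sft(feat, scale, shift):
    # scale and shift have the same shape as feat, so every channel at
    # every pixel gets its own affine modulation
    return feat * scale + shift


rng = np.random.default_rng(0)
conditions = []
for size in (8, 16, 32):                           # hypothetical decoder resolutions
    shape = (1, 4, size, size)                     # (batch, channels, H, W)
    conditions.append(rng.standard_normal(shape))  # scale for this level
    conditions.append(rng.standard_normal(shape))  # shift for this level

# conditions is interleaved: [scale_0, shift_0, scale_1, shift_1, ...]
pairs = list(zip(conditions[0::2], conditions[1::2]))

feat = rng.standard_normal((1, 4, 16, 16))         # decoder feature at 16x16
scale, shift = pairs[1]                            # the pair built at 16x16
out = sft(feat, scale, shift)
print(len(pairs), out.shape)  # 3 (1, 4, 16, 16)
```

With scale = 1 and shift = 0, the transform reduces to the identity.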
