Style Transfer Training in Practice

The previous article shared the code for a simple style transfer in PyTorch. In the spirit of not giving up until the server falls over, I kept increasing the number of optimization steps to see how the blended images turn out.

To make it convenient to run everything in one go, the code has been lightly reworked; it is largely the same as in the previous article:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models
import datetime

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_img_size(img_name):
    """
    Open an image and return it together with its size.
    :param img_name: path of the image file
    :return: PIL image, height, width
    """
    im = Image.open(img_name).convert('RGB')
    return im, im.height, im.width


def image_loader(img, im_h, im_w):
    """
    Load a PIL image as a tensor on the target device.
    :param img: PIL image
    :param im_h: image height (currently unused, the resize size is hard-coded below)
    :param im_w: image width (currently unused, the resize size is hard-coded below)
    :return: 4-D image tensor on `device`
    """

    # loader = transforms.Compose([transforms.Resize([im_h, im_w]), transforms.ToTensor()])
    loader = transforms.Compose([transforms.Resize([1000, 1000]), transforms.ToTensor()])
    im_l = loader(img).unsqueeze(0)
    return im_l.to(device, torch.float)


def im_show(tensor, save_file_path):
    """
    Display the image and save it to disk.
    :param tensor: image tensor
    :param save_file_path: path to save the image to
    :return:
    """
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = transforms.ToPILImage()(image)
    plt.imshow(image, aspect='equal')
    plt.axis('off')
    plt.savefig(save_file_path, bbox_inches='tight', pad_inches=0.0)
    plt.pause(0.001)


class ContentLoss(nn.Module):
    """
    Content loss: mean squared error against the detached target feature map.
    """

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        self.target = target.detach()

    def forward(self, cl_input):
        self.loss = F.mse_loss(cl_input, self.target)
        return cl_input


def gram_matrix(gm_input):
    """
    Gram matrix of a feature map, used by the style loss.
    :param gm_input: feature map of shape (batch, channels, height, width)
    :return: normalized Gram matrix
    """
    a, b, c, d = gm_input.size()
    features = gm_input.view(a * b, c * d)
    G = torch.mm(features, features.t())

    return G.div(a * b * c * d)


class StyleLoss(nn.Module):
    """
    Style loss: mean squared error between Gram matrices.
    """

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, fw_input):
        G = gram_matrix(fw_input)
        self.loss = F.mse_loss(G, self.target)
        return fw_input


# Use the 19-layer VGG network as the feature extractor
cnn = models.vgg19(pretrained=True).features.to(device).eval()


cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)


class Normalization(nn.Module):
    """
    Normalize the input image with the VGG mean and standard deviation.
    """
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = mean.clone().detach().view(-1, 1, 1)
        self.std = std.clone().detach().view(-1, 1, 1)

    def forward(self, img):
        return (img - self.mean) / self.std


def get_style_model_and_losses(cn, normalization_mean, normalization_std, style_i, content_i, cld, sld):
    """
    Rebuild the VGG feature stack and insert the content and style loss modules.
    :param cn: pretrained CNN (VGG19 features)
    :param normalization_mean: normalization mean
    :param normalization_std: normalization standard deviation
    :param style_i: style image tensor
    :param content_i: content image tensor
    :param cld: names of the layers used for the content loss
    :param sld: names of the layers used for the style loss
    :return: trimmed model, style loss modules, content loss modules
    """

    normalization = Normalization(normalization_mean, normalization_std).to(device)
    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    i = 0
    for layer in cn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in cld:
            target = model(content_i).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in sld:
            target_feature = model(style_i).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # drop every layer after the last loss module
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses


def get_input_optimizer(input_i):
    """
    Use the L-BFGS optimizer to minimize the style and content losses
    with respect to the input image.
    :param input_i: input image tensor
    :return: optimizer
    """
    optimizer = optim.LBFGS([input_i])
    return optimizer


def run_style_transfer(cn, norma_mean, normalization_std, ct_img, sl_img, in_img, steps, style_weight, content_weight):
    """
    Build the style-transfer model and optimize the input image.
    :param cn: pretrained CNN (VGG19 features)
    :param norma_mean: normalization mean
    :param normalization_std: normalization standard deviation
    :param ct_img: content image tensor
    :param sl_img: style image tensor
    :param in_img: input image tensor to be optimized
    :param steps: number of optimization steps
    :param style_weight: weight of the style loss
    :param content_weight: weight of the content loss
    :return: stylized image tensor
    """
    content_layers = ['conv_4']
    style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
    model, style_losses, content_losses = get_style_model_and_losses(cn, norma_mean, normalization_std, sl_img, ct_img, content_layers, style_layers)
    in_img.requires_grad_(True)
    model.requires_grad_(False)

    optimizer = get_input_optimizer(in_img)
    print('Optimizing..')
    run = [0]
    while run[0] <= steps:

        def closure():
            # keep the image in the valid [0, 1] range
            with torch.no_grad():
                in_img.clamp_(0, 1)

            optimizer.zero_grad()
            model(in_img)
            style_score = 0
            content_score = 0

            for sl in style_losses:
                style_score += sl.loss
            for cl in content_losses:
                content_score += cl.loss

            style_score *= style_weight
            content_score *= content_weight

            loss = style_score + content_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(style_score.item(), content_score.item()))
            return style_score + content_score

        optimizer.step(closure)
    with torch.no_grad():
        in_img.clamp_(0, 1)
    return in_img


def style_transfer(content_image_path, style_image_path, image_save_path, run_steps):
    """
    Main entry point of the style transfer.
    :param content_image_path: path of the content image
    :param style_image_path: path of the style image
    :param image_save_path: path the result image is saved to
    :param run_steps: number of optimization steps to run
    :return:
    """
    c_image, c_im_h, c_im_w = get_img_size(content_image_path)
    s_image, s_im_h, s_im_w = get_img_size(style_image_path)
    content_img = image_loader(c_image, c_im_h, c_im_w)
    style_img = image_loader(s_image, c_im_h, c_im_w)
    assert style_img.size() == content_img.size()
    # use the content image as the initial input image
    input_img = content_img.clone()
    begin_time = datetime.datetime.now()
    print("****************** start time *****************", begin_time)
    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, input_img, run_steps, s_weight, c_weight)
    try:
        im_show(output, image_save_path)
    except Exception as e:
        print(e)
    print("****************** end time *****************", datetime.datetime.now())
    print("****************** elapsed *****************", datetime.datetime.now() - begin_time)


if __name__ == '__main__':
    s_weight = 1000000
    c_weight = 1
    # content_img_path = "data/drew/img/512.png"
    content_img_path = "/data/drew/img/dancing.jpg"
    # style_img_path = "data/drew/img/512r.png"
    style_img_path = "/data/drew/img/picasso.jpg"
    for steps in range(100, 3200, 200):
        # save_path = "data/drew/img/end_%s_%s.jpg" % (steps, datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
        save_path = "/data/drew/img/end_%s_%s.jpg" % (steps, datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
        style_transfer(content_img_path, style_img_path, save_path, steps)
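One note on saving results: im_show() goes through matplotlib, which re-renders the figure and can change the output resolution. If only the stylized tensor needs to be written to disk, torchvision's save_image can do it directly. Below is a minimal sketch, assuming output is the (1, 3, H, W) tensor returned by run_style_transfer(); save_output is a hypothetical helper, not part of the code above.

from torchvision.utils import save_image


def save_output(output, save_file_path):
    # hypothetical helper: write the stylized tensor to disk without matplotlib
    # clamp to the valid [0, 1] pixel range and drop the batch dimension
    img = output.detach().cpu().clamp(0, 1).squeeze(0)
    save_image(img, save_file_path)

If the on-screen preview is not needed, a helper like this could stand in for the im_show() call inside style_transfer().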