# Composite images for the current batch. `masks` multiplies `ground_truths`,
# so presumably 1 = known pixel and 0 = hole (TODO confirm against the loader).
# damaged_images: ground truth kept where masks == 1, filled with 1.0 elsewhere.
# outputs_comps: ground truth where masks == 1, network output elsewhere.
damaged_images = ground_truths * masks + (1 - masks)
outputs_comps = ground_truths * masks + outputs * (1 - masks)
if count % 1000 == 0:
    # Every 1000th step, write one PNG: one row per sample, five panels per
    # row (1 - mask, damaged input, raw output, composite, ground truth).
    # Visualisation only, so keep these ops out of the autograd graph.
    with torch.no_grad():
        panels = (1 - masks, damaged_images, outputs, outputs_comps, ground_truths)
        # expand_as reproduces the broadcasting the old per-row assignment
        # allowed (e.g. a single-channel mask against multi-channel images);
        # .float() matches the float32 buffer the old code allocated.
        stacked = torch.stack(
            [panel.expand_as(ground_truths).float() for panel in panels], dim=1
        )  # (N, 5, C, H, W)
        # Flattening the first two dims puts sample i's panels at indices
        # 5*i .. 5*i+4, the same order the old index-by-index fill produced.
        save_images_5 = stacked.reshape(-1, *ground_truths.shape[1:]).cpu()
    save_image(save_images_5, os.path.join(args.sample_path, '{:05d}.png'.format(count)), nrow=5)
def save_image(tensor, filename, nrow=8, padding=2,
               normalize=False, range=None, scale_each=False, pad_value=0):
    """Save a tensor (or list of tensors) to a single image file.

    The input is tiled into one grid image by ``make_grid``. The grid is then
    scaled by 255, rounded to the nearest integer and clamped to [0, 255],
    converted to ``uint8`` HWC layout and written with PIL. Values outside
    [0, 1] are therefore clipped.

    Args:
        tensor (Tensor or list): Image(s) to save; passed to ``make_grid``.
        filename (str): Destination path given to ``PIL.Image.Image.save``.
        nrow, padding, normalize, range, scale_each, pad_value: Forwarded
            unchanged to ``make_grid``; see its documentation. ``range``
            shadows the builtin but keeps its name for keyword callers.

    This function does not modify ``tensor`` in place and records no
    autograd history.
    """
    from PIL import Image
    with torch.no_grad():
        grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                         normalize=normalize, range=range, scale_each=scale_each)
        # Out-of-place ``mul`` first: make_grid may return a view of its input
        # (torchvision's does for a one-image batch), so an in-place ``mul_``
        # would rescale the caller's tensor. The later in-place ops then act
        # on this fresh copy. Adding 0.5 before the uint8 cast rounds to nearest.
        ndarr = (grid.mul(255).add_(0.5).clamp_(0, 255)
                 .permute(1, 2, 0).to('cpu', torch.uint8).numpy())
    im = Image.fromarray(ndarr)
    im.save(filename)
32
# View Code