| import argparse | |
| import os | |
| import numpy as np | |
| from AdaIN import StyleTransferNet | |
| from PIL import Image | |
| import torch | |
| import torchvision.transforms as transforms | |
| from torchvision.utils import save_image | |
class AlphaRange(object):
    """Closed numeric interval that compares equal to any value inside it.

    Intended for use as an argparse ``choices`` entry: argparse checks a parsed
    value for membership in ``choices``, and membership in a list falls back to
    ``==``, so ``choices=[AlphaRange(0.0, 1.0)]`` accepts any value in [0, 1].
    """

    # Defining __eq__ removes the inherited hash; state that explicitly so the
    # class is visibly (not accidentally) unhashable.
    __hash__ = None

    def __init__(self, start, end):
        # Inclusive lower and upper bounds of the accepted interval.
        self.start = start
        self.end = end

    def __eq__(self, other):
        """Return True when ``start <= other <= end``.

        Values that cannot be ordered against the bounds (``None``, strings,
        ...) yield ``NotImplemented`` so Python reports them as "not equal"
        instead of raising ``TypeError``.
        """
        try:
            return self.start <= other <= self.end
        except TypeError:
            return NotImplemented

    def __str__(self):
        return 'Alpha Range'

    def __repr__(self):
        # Shows the actual bounds wherever the object is formatted with repr().
        return 'AlphaRange({!r}, {!r})'.format(self.start, self.end)
# Command-line interface for the style-transfer script.
# The three path options are marked required: without them the script would
# hand None to the image/weight loaders and fail with an unrelated error.
parser = argparse.ArgumentParser()
parser.add_argument('--input_image', type=str, required=True, help='test image')
parser.add_argument('--style_image', type=str, required=True, help='style image')
parser.add_argument('--weight', type=str, required=True, help='decoder weight file')
# AlphaRange compares equal to every float inside [0.0, 1.0], so this single
# "choice" validates the whole interval.
parser.add_argument('--alpha', type=float, default=1.0, choices=[AlphaRange(0.0, 1.0)], help='Level of style transfer, value between 0 and 1')
# This script only runs inference; the flag takes effect only when CUDA is available.
parser.add_argument('--cuda', action='store_true', help='Use GPU for inference when available')
def build_output_path(input_path, style_path, out_dir='./results/'):
    """Return the path under ``out_dir`` where a stylized result is saved.

    The file name is ``<input stem>_style_<style stem><input extension>``.
    Stems and the extension come from the final path component only, so
    leading ``./``/``../`` segments or dots in directory names do not affect
    the result.

    Args:
        input_path: Path of the content image.
        style_path: Path of the style image.
        out_dir: Directory the result is placed in.

    Returns:
        The joined output path as a string.
    """
    input_stem, extension = os.path.splitext(os.path.basename(input_path))
    style_stem = os.path.splitext(os.path.basename(style_path))[0]
    # An extension-less name gives the image writer no format hint; use PNG.
    file_name = input_stem + '_style_' + style_stem + (extension or '.png')
    return os.path.join(out_dir, file_name)


def main():
    """Stylize ``--input_image`` with ``--style_image`` and save the result."""
    opt = parser.parse_args()

    out_dir = './results/'
    os.makedirs(out_dir, exist_ok=True)

    # Force three channels (grayscale/RGBA inputs would otherwise produce 1- or
    # 4-channel tensors) and close the underlying files once decoded.
    with Image.open(opt.input_image) as source:
        input_image = source.convert('RGB')
    with Image.open(opt.style_image) as source:
        style_image = source.convert('RGB')

    use_cuda = opt.cuda and torch.cuda.is_available()

    with torch.no_grad():
        # NOTE: torch.load unpickles its input -- only load trusted files.
        # NOTE(review): newer PyTorch releases may default to
        # weights_only=True, which would reject a fully pickled model such as
        # this one -- confirm against the installed version.
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # machines; the network is moved to the GPU below when requested.
        vgg_model = torch.load('vgg_normalized.pth', map_location='cpu')
        net = StyleTransferNet(vgg_model)
        net.decoder.load_state_dict(torch.load(opt.weight, map_location='cpu'))
        net.eval()

        resize = transforms.Resize(512)
        to_tensor = transforms.ToTensor()
        # unsqueeze(0) adds the leading batch dimension.
        input_tensor = to_tensor(resize(input_image)).unsqueeze(0)
        style_tensor = to_tensor(resize(style_image)).unsqueeze(0)

        if use_cuda:
            net.cuda()
            input_tensor = input_tensor.cuda()
            style_tensor = style_tensor.cuda()

        out_tensor = net([input_tensor, style_tensor], alpha=opt.alpha)
        save_image(out_tensor,
                   build_output_path(opt.input_image, opt.style_image, out_dir))


if __name__ == '__main__':
    main()