# adain / test_style_transfer.py
# Chengkai Yang
# init
# 7930ce0
# raw
# history blame
# 1.98 kB
import argparse
import os
import numpy as np
from AdaIN import StyleTransferNet
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
class AlphaRange(object):
    """Closed numeric interval usable as an ``argparse`` ``choices`` entry.

    ``argparse`` validates a parsed value by testing membership in
    ``choices``, and membership falls back to ``==``.  Overriding ``__eq__``
    as a range test therefore lets one instance accept every value in
    ``[start, end]``.

    Args:
        start: Lower bound of the range (inclusive).
        end: Upper bound of the range (inclusive).
    """

    # Equality is a range test, not an equivalence relation, so instances
    # must never be used as dict keys or set members.
    __hash__ = None

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __eq__(self, other):
        """Return True when ``start <= other <= end``.

        Operands that cannot be ordered against the bounds (e.g. strings)
        yield ``NotImplemented`` so that ``==`` evaluates to False instead of
        raising ``TypeError``.
        """
        try:
            return self.start <= other <= self.end
        except TypeError:
            return NotImplemented

    def __repr__(self):
        # Some argparse versions list the choices with repr() in the
        # "invalid choice" error; show the real bounds there.
        return '{}({!r}, {!r})'.format(
            type(self).__name__, self.start, self.end)

    def __str__(self):
        return 'Alpha Range'
# Command-line interface of the style-transfer test script.
# The three path arguments are mandatory: the script opens all of them
# unconditionally, so a missing one is reported by argparse up front instead
# of surfacing later as an obscure failure on a ``None`` path.
parser = argparse.ArgumentParser()
parser.add_argument('--input_image', type=str, required=True,
                    help='test image')
parser.add_argument('--style_image', type=str, required=True,
                    help='style image')
parser.add_argument('--weight', type=str, required=True,
                    help='decoder weight file')
# A single AlphaRange entry in ``choices`` accepts any float in [0, 1].
parser.add_argument('--alpha', type=float, default=1.0,
                    choices=[AlphaRange(0.0, 1.0)],
                    help='Level of style transfer, value between 0 and 1')
# This script only runs inference, so the flag selects the inference device.
parser.add_argument('--cuda', action='store_true',
                    help='Use the GPU for inference when available')
def build_output_path(input_path, style_path, out_dir='./results/'):
    """Return the file path under which the stylised image is saved.

    The file name is ``<input stem>_style_<style stem><input extension>``.
    Directory parts and extensions are split with ``os.path`` so that paths
    containing extra dots (``./data/cat.jpg``, ``../cat.jpg``) are handled
    correctly; searching for the first ``'.'`` in the raw string is not.

    Args:
        input_path: Path of the content image.
        style_path: Path of the style image.
        out_dir: Directory that receives the result.

    Returns:
        The joined output path.  When ``input_path`` has no extension,
        ``.png`` is used so the image writer can still pick a format.
    """
    input_stem, extension = os.path.splitext(os.path.basename(input_path))
    style_stem = os.path.splitext(os.path.basename(style_path))[0]
    file_name = input_stem + '_style_' + style_stem + (extension or '.png')
    return os.path.join(out_dir, file_name)


if __name__ == '__main__':
    opt = parser.parse_args()

    out_dir = './results/'
    os.makedirs(out_dir, exist_ok=True)

    # Convert to RGB so grayscale / RGBA / palette files also yield 3-channel
    # tensors (presumably what the VGG encoder expects -- TODO confirm).
    # The context managers close the underlying file handles.
    with Image.open(opt.input_image) as source:
        input_image = source.convert('RGB')
    with Image.open(opt.style_image) as source:
        style_image = source.convert('RGB')

    with torch.no_grad():
        # map_location='cpu' lets checkpoints that were saved on a GPU load on
        # CPU-only machines; the network is moved to the GPU below on request.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may reject a pickled module such as
        # 'vgg_normalized.pth' -- confirm against the installed version.
        vgg_model = torch.load('vgg_normalized.pth', map_location='cpu')
        net = StyleTransferNet(vgg_model)
        net.decoder.load_state_dict(
            torch.load(opt.weight, map_location='cpu'))
        net.eval()

        resize = transforms.Resize(512)
        to_tensor = transforms.ToTensor()
        # unsqueeze(0) adds the leading batch dimension.
        input_tensor = to_tensor(resize(input_image)).unsqueeze(0)
        style_tensor = to_tensor(resize(style_image)).unsqueeze(0)

        if torch.cuda.is_available() and opt.cuda:
            net.cuda()
            input_tensor = input_tensor.cuda()
            style_tensor = style_tensor.cuda()

        out_tensor = net([input_tensor, style_tensor], alpha=opt.alpha)

    save_image(
        out_tensor,
        build_output_path(opt.input_image, opt.style_image, out_dir))