Commit b5bc801b authored by bryandlee's avatar bryandlee

Refactor: Remove cv2 depedency

parent f11e0567
import os
import argparse import argparse
import torch from PIL import Image
import cv2
import numpy as np import numpy as np
import os
import torch
from torchvision.transforms.functional import to_tensor, to_pil_image
from model import Generator from model import Generator
torch.backends.cudnn.enabled = False torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
def load_image(image_path, x32=False): def load_image(image_path, x32=False):
img = cv2.imread(image_path).astype(np.float32) img = Image.open(image_path).convert("RGB")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w = img.shape[:2]
if x32: # resize image to multiple of 32s if x32:
def to_32s(x): def to_32s(x):
return 256 if x < 256 else x - x%32 return 256 if x < 256 else x - x % 32
img = cv2.resize(img, (to_32s(w), to_32s(h))) w, h = img.size
img = img.resize((to_32s(w), to_32s(h)))
img = torch.from_numpy(img)
img = img/127.5 - 1.0
return img return img
...@@ -43,22 +44,22 @@ def test(args): ...@@ -43,22 +44,22 @@ def test(args):
image = load_image(os.path.join(args.input_dir, image_name), args.x32) image = load_image(os.path.join(args.input_dir, image_name), args.x32)
with torch.no_grad(): with torch.no_grad():
input = image.permute(2, 0, 1).unsqueeze(0).to(device) image = to_tensor(image).unsqueeze(0) * 2 - 1
out = net(input, args.upsample_align).squeeze(0).permute(1, 2, 0).cpu().numpy() out = net(image.to(device), args.upsample_align).cpu()
out = (out + 1)*127.5 out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5
out = np.clip(out, 0, 255).astype(np.uint8) out = to_pil_image(out)
cv2.imwrite(os.path.join(args.output_dir, image_name), cv2.cvtColor(out, cv2.COLOR_BGR2RGB)) out.save(os.path.join(args.output_dir, image_name))
print(f"image saved: {image_name}") print(f"image saved: {image_name}")
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'--checkpoint', '--checkpoint',
type=str, type=str,
default='./pytorch_generator_Paprika.pt', default='./weights/paprika.pt',
) )
parser.add_argument( parser.add_argument(
'--input_dir', '--input_dir',
...@@ -79,12 +80,13 @@ if __name__ == '__main__': ...@@ -79,12 +80,13 @@ if __name__ == '__main__':
'--upsample_align', '--upsample_align',
type=bool, type=bool,
default=False, default=False,
help="Align corners in decoder upsampling layers"
) )
parser.add_argument( parser.add_argument(
'--x32', '--x32',
action="store_true", action="store_true",
help="Resize images to multiple of 32"
) )
args = parser.parse_args() args = parser.parse_args()
test(args) test(args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment