Commit 2b8bff84 authored by mac

2ptl

parent 0e9dd161

import argparse

import numpy as np
import torch
import torchvision
from torchsummary import summary
from torch.utils.mobile_optimizer import optimize_for_mobile
from PIL import Image
import tensorflow as tf
from torchvision.transforms.functional import to_tensor, to_pil_image

from model import Generator

def load_image(image_path, x32=False):
    img = Image.open(image_path).convert("RGB")
    if x32:
        # Snap both sides down to a multiple of 32 (minimum 256) so the
        # generator's downsampling/upsampling stages line up.
        def to_32s(x):
            return 256 if x < 256 else x - x % 32
        w, h = img.size
        img = img.resize((to_32s(w), to_32s(h)))
    return img
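# Example: with --x32 set, to_32s(500) == 480 and to_32s(333) == 320, so a
# 500x333 input is resized to 480x320; any side under 256 is bumped up to 256.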
#model = torchvision.models.mobilenet_v2(pretrained=True)
#model.eval()
#example = torch.rand(1, 3, 224, 224)
#traced_script_module = torch.jit.trace(model, example)
#traced_script_module_optimized = optimize_for_mobile(traced_script_module)
#traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")

def transModel2Ptl(args):
    net = Generator()
    net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
    net.to(args.device).eval()
    print(f"model loaded: {args.checkpoint}")
    print('model:', net)
    # torch.jit.trace needs a concrete example input; dynamic (None)
    # dimensions are not allowed, so a fixed 1x3x500x500 tensor is used.
    #example = torch.rand(1, 3, None, None)
    example = torch.rand(1, 3, 500, 500)
    traced_script_module = torch.jit.trace(net, example)
    traced_script_module_optimized = optimize_for_mobile(traced_script_module)
    traced_script_module_optimized._save_for_lite_interpreter(args.output_dir + args.filename)
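
# Optional sanity check (a sketch, not part of the original flow; check_trace is
# a name introduced here): run the same input through the eager model and the
# traced module and compare. Tracing freezes control flow, so a large mismatch
# usually means the example input did not exercise every branch of the model.
def check_trace(net, traced_module, size=500):
    x = torch.rand(1, 3, size, size)
    with torch.no_grad():
        ref = net(x)
        out = traced_module(x)
    assert torch.allclose(ref, out, atol=1e-5), "traced output diverges from eager model"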

def testptl(args):
    ptl = '/animegan2-pytorch/weights/p.ptl'  # args.output_dir + args.filename
    print('model path:', ptl)
    # tf.contrib was removed in TensorFlow 2.x; the interpreter lives at tf.lite.
    interpreter = tf.lite.Interpreter(model_path=ptl)
    print(interpreter.get_input_details())
    print(interpreter.get_output_details())
    print(interpreter.get_tensor_details())
    interpreter.allocate_tensors()
    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Test the model on input data.
    input_shape = input_details[0]['shape']
    print(input_shape)
    image = load_image(args.image_file, args.x32)
    image = np.asarray(image, np.float32)
    image = np.expand_dims(image, axis=0)
    print('image shape:', image.shape)
    # Resize the input tensor to the actual image shape before binding it.
    interpreter.resize_tensor_input(input_details[0]['index'], image.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], image)
    interpreter.invoke()
    # `get_tensor()` returns a copy of the tensor data;
    # use `tensor()` to get a pointer to the tensor instead.
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print(output_data)
    output_data = np.squeeze(output_data)
    output_data = tf.clip_by_value(output_data, 0, 255)
    output_data = tf.round(output_data)
    output_data = tf.cast(output_data, tf.uint8)
    print('1:', output_data.shape)
    # to_pil_image expects a torch tensor or numpy array, not a tf.Tensor,
    # and a PIL image reports .size rather than .shape.
    output_data = to_pil_image(output_data.numpy())
    print('2:', output_data.size)
    output_data.save('tt.jpg')
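
# NOTE: tf.lite.Interpreter parses TFLite flatbuffers, so the .ptl written by
# transModel2Ptl will not load through testptl. Below is a minimal PyTorch-side
# sketch using torch.jit.mobile's loader for lite-interpreter files (the
# underscore prefix marks it as technically private). The [-1, 1] input
# normalization is an assumption carried over from the upstream
# animegan2-pytorch test script, and test_ptl_torch is a name introduced here.
from torch.jit.mobile import _load_for_lite_interpreter

def test_ptl_torch(args):
    module = _load_for_lite_interpreter(args.output_dir + args.filename)
    img = load_image(args.image_file, args.x32)
    x = to_tensor(img).unsqueeze(0) * 2 - 1  # NCHW float32 in [-1, 1]
    with torch.no_grad():
        out = module(x)
    out = out.squeeze(0).clamp(-1, 1) * 0.5 + 0.5  # map back to [0, 1] for saving
    to_pil_image(out).save('tt_torch.jpg')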

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='./weights/paprika.pt',
    )
    parser.add_argument(
        '--input_dir',
        type=str,
        default='./samples/inputs',
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='./weights/',
    )
    parser.add_argument(
        '--filename',
        type=str,
        default='p.ptl',
    )
    parser.add_argument(
        '--image_file',
        type=str,
        default='./samples/inputs/1.jpg',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cpu',  # or 'cuda:0'
    )
    parser.add_argument(
        '--upsample_align',
        # argparse's type=bool treats any non-empty string (even "False") as
        # True, so a store_true flag is the correct way to expose this switch.
        action="store_true",
        help="Align corners in decoder upsampling layers",
    )
    parser.add_argument(
        '--x32',
        action="store_true",
        help="Resize images to multiple of 32",
    )
    args = parser.parse_args()
    #transModel2Ptl(args)
    testptl(args)
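
# Usage (illustrative; the script name is a placeholder):
#   python convert2ptl.py --checkpoint ./weights/paprika.pt --filename p.ptl
# As committed, only testptl(args) runs; uncomment transModel2Ptl(args) above
# to export the .ptl first.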