Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
A
animegan2-pytorch
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
git
animegan2-pytorch
Commits
2b8bff84
Commit
2b8bff84
authored
Apr 04, 2023
by
mac
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
2ptl
parent
0e9dd161
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
142 additions
and
0 deletions
+142
-0
convert_ptl.py
convert_ptl.py
+142
-0
No files found.
convert_ptl.py
0 → 100644
View file @
2b8bff84
import
argparse
import
torch
import
torchvision
from
torchsummary
import
summary
from
torch.utils.mobile_optimizer
import
optimize_for_mobile
from
model
import
Generator
from
PIL
import
Image
import
tensorflow
as
tf
from
torchvision.transforms.functional
import
to_tensor
,
to_pil_image
def
load_image
(
image_path
,
x32
=
False
):
img
=
Image
.
open
(
image_path
)
.
convert
(
"RGB"
)
if
x32
:
def
to_32s
(
x
):
return
256
if
x
<
256
else
x
-
x
%
32
w
,
h
=
img
.
size
img
=
img
.
resize
((
to_32s
(
w
),
to_32s
(
h
)))
return
img
#model = torchvision.models.mobilenet_v2(pretrained=True)
#model.eval()
#example = torch.rand(1, 3, 224, 224)
#traced_script_module = torch.jit.trace(model, example)
#traced_script_module_optimized = optimize_for_mobile(traced_script_module)
#traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")
def
transModel2Ptl
(
args
):
print
(
f
"model loaded: {args.checkpoint}"
)
net
=
Generator
()
net
.
load_state_dict
(
torch
.
load
(
args
.
checkpoint
,
map_location
=
"cpu"
))
net
.
to
(
args
.
device
)
.
eval
()
# print('shape:', net.shape)
print
(
'model:'
,
net
)
#example = torch.rand(1, 3, None, None)
example
=
torch
.
rand
(
1
,
3
,
500
,
500
)
traced_script_module
=
torch
.
jit
.
trace
(
net
,
example
)
traced_script_module_optimized
=
optimize_for_mobile
(
traced_script_module
)
traced_script_module_optimized
.
_save_for_lite_interpreter
(
args
.
output_dir
+
args
.
filename
)
def
testptl
(
args
):
ptl
=
'/animegan2-pytorch/weights/p.ptl'
#args.output_dir+args.filename
print
(
'model path:'
,
ptl
)
interpreter
=
tf
.
contrib
.
lite
.
Interpreter
(
model_path
=
ptl
)
print
(
interpreter
.
get_input_details
())
print
(
interpreter
.
get_output_details
())
print
(
interpreter
.
get_tensor_details
())
interpreter
.
allocate_tensors
()
# Get input and output tensors.
input_details
=
interpreter
.
get_input_details
()
output_details
=
interpreter
.
get_output_details
()
# Test the model on input data.
input_shape
=
input_details
[
0
][
'shape'
]
print
(
input_shape
)
image
=
load_image
(
args
.
image_file
,
args
.
x32
)
image
=
np
.
asarray
(
image
,
np
.
float32
)
image
=
np
.
expand_dims
(
image
,
axis
=
0
)
print
(
'image shape:'
,
image
.
shape
)
interpreter
.
resize_tensor_input
(
input_details
[
0
][
'index'
],
image
.
shape
)
interpreter
.
allocate_tensors
()
# Use same image as Keras model
interpreter
.
set_tensor
(
input_details
[
0
][
'index'
],
image
)
interpreter
.
invoke
()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data
=
interpreter
.
get_tensor
(
output_details
[
0
][
'index'
])
print
(
output_data
)
output_data
=
np
.
squeeze
(
output_data
)
output_data
=
tf
.
clip_by_value
(
output_data
,
0
,
255
)
output_data
=
tf
.
round
(
output_data
)
output_data
=
tf
.
cast
(
output_data
,
tf
.
uint8
)
print
(
'1:'
,
output_data
.
shape
)
output_data
=
to_pil_image
(
output_data
)
print
(
'2:'
,
output_data
.
shape
)
output_data
.
save
(
'tt.jpg'
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
'./weights/paprika.pt'
,
)
parser
.
add_argument
(
'--input_dir'
,
type
=
str
,
default
=
'./samples/inputs'
,
)
parser
.
add_argument
(
'--output_dir'
,
type
=
str
,
default
=
'./weights/'
,
)
parser
.
add_argument
(
'--filename'
,
type
=
str
,
default
=
'p.ptl'
,
)
parser
.
add_argument
(
'--image_file'
,
type
=
str
,
default
=
'./samples/inputs/1.jpg'
,
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cpu'
,
#cuda:0
)
parser
.
add_argument
(
'--upsample_align'
,
type
=
bool
,
default
=
False
,
help
=
"Align corners in decoder upsampling layers"
)
parser
.
add_argument
(
'--x32'
,
action
=
"store_true"
,
help
=
"Resize images to multiple of 32"
)
args
=
parser
.
parse_args
()
#transModel2Ptl(args)
testptl
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment