Commit 4a9c882a authored by xhlulu's avatar xhlulu

Add known networks

parent 449279b6
import torch
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
release_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights"
known = {
name: f"{release_url}/{name}.pt"
for name in [
'face_paint_512_v0', 'face_paint_512_v2'
]
}
from model import Generator
device = torch.device(device)
......@@ -8,10 +16,11 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
model = Generator().to(device)
if type(pretrained) == str:
ckpt_url = pretrained
# Look if a known name is passed, otherwise assume it's a URL
ckpt_url = known.get(pretrained, pretrained)
pretrained = True
else:
ckpt_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights/face_paint_512_v2_0.pt"
ckpt_url = known.get('face_paint_512_v2')
if pretrained is True:
state_dict = torch.hub.load_state_dict_from_url(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment