Unverified Commit 88f75498 authored by bryandlee's avatar bryandlee Committed by GitHub

Merge pull request #28 from xhlulu/main

Fix uploader issues in colab_demo.ipynb
parents cd65ad5c 927c6c14
...@@ -74,6 +74,8 @@ ...@@ -74,6 +74,8 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Anime FaceGAN Colab app\n",
"\n",
"from io import BytesIO\n", "from io import BytesIO\n",
"import torch\n", "import torch\n",
"from PIL import Image\n", "from PIL import Image\n",
...@@ -85,38 +87,30 @@ ...@@ -85,38 +87,30 @@
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model = torch.hub.load(\"bryandlee/animegan2-pytorch:main\", \"generator\", device=device).eval()\n", "model = torch.hub.load(\"bryandlee/animegan2-pytorch:main\", \"generator\", device=device).eval()\n",
"face2paint = torch.hub.load(\"bryandlee/animegan2-pytorch:main\", \"face2paint\", device=device)\n", "face2paint = torch.hub.load(\"bryandlee/animegan2-pytorch:main\", \"face2paint\", device=device)\n",
"\n", "image_format = \"png\" #@param [\"jpeg\", \"png\"]\n",
"#@title Anime FaceGAN Colab app\n",
"image_format = \"jpeg\" #@param [\"jpeg\", \"png\"]\n",
"\n",
"\n", "\n",
"button = widgets.Button(description=\"Start\")\n", "button = widgets.Button(description=\"Start\")\n",
"output = widgets.Output()\n", "output = widgets.Output()\n",
"\n", "\n",
"uploader = widgets.FileUpload(\n",
" accept='image/*', # Accepted file extension e.g. '.txt', '.pdf', 'image/*', 'image/*,.pdf'\n",
" multiple=False # True to accept multiple files upload else False\n",
")\n",
"\n", "\n",
"def run(b):\n", "def run(b):\n",
" button.disabled = True\n", " button.disabled = True\n",
"\n", "\n",
" with output:\n", " with output:\n",
" display.clear_output()\n", " display.clear_output()\n",
" \n",
" uploaded = files.upload()\n",
"\n", "\n",
" for fname in uploader.value:\n", " for fname in uploaded:\n",
" bytes_in = uploader.value[fname]['content']\n", " bytes_in = uploaded[fname]\n",
"\n", "\n",
" im_in = Image.open(BytesIO(bytes_in)).convert(\"RGB\")\n", " im_in = Image.open(BytesIO(bytes_in)).convert(\"RGB\")\n",
" im_out = face2paint(model, im_in, side_by_side=False)\n", " im_out = face2paint(model, im_in, side_by_side=False)\n",
"\n",
" buffer_out = BytesIO()\n", " buffer_out = BytesIO()\n",
" im_out.save(buffer_out, format=image_format)\n", " im_out.save(buffer_out, format=image_format)\n",
"\n",
" bytes_out = buffer_out.getvalue()\n", " bytes_out = buffer_out.getvalue()\n",
" wi1 = widgets.Image(value=bytes_in, format=image_format)\n", " wi1 = widgets.Image(value=bytes_in, format=image_format)\n",
" wi2 = widgets.Image(value=bytes_out, format=image_format)\n", " wi2 = widgets.Image(value=bytes_out, format=image_format)\n",
"\n",
" wi1.layout.max_width = '500px'\n", " wi1.layout.max_width = '500px'\n",
" wi1.layout.max_height = '500px'\n", " wi1.layout.max_height = '500px'\n",
" wi2.layout.max_width = '500px'\n", " wi2.layout.max_width = '500px'\n",
...@@ -131,7 +125,7 @@ ...@@ -131,7 +125,7 @@
" button.disabled = False\n", " button.disabled = False\n",
"\n", "\n",
"button.on_click(run)\n", "button.on_click(run)\n",
"display.display(uploader, button, output)" "display.display(button, output)"
] ]
} }
], ],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment