animegan2-pytorch/test_faces.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import random\n",
"import numpy as np\n",
"\n",
"from model import Generator\n",
"\n",
"def load_image(path, size=None):\n",
" image = image2tensor(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB))\n",
"\n",
" w, h = image.shape[-2:]\n",
" if w != h:\n",
" crop_size = min(w, h)\n",
" left = (w - crop_size)//2\n",
" right = left + crop_size\n",
" top = (h - crop_size)//2\n",
" bottom = top + crop_size\n",
" image = image[:,:,left:right, top:bottom]\n",
"\n",
" if size is not None and image.shape[-1] != size:\n",
" image = torch.nn.functional.interpolate(image, (size, size), mode=\"bilinear\", align_corners=True)\n",
" \n",
" return image\n",
"\n",
"def image2tensor(image):\n",
" image = torch.FloatTensor(image).permute(2,0,1).unsqueeze(0)/255.\n",
" return (image-0.5)/0.5\n",
"\n",
"def tensor2image(tensor):\n",
" tensor = tensor.clamp_(-1., 1.).detach().squeeze().permute(1,2,0).cpu().numpy()\n",
" return tensor*0.5 + 0.5\n",
"\n",
"def imshow(img, size=5, cmap='jet'):\n",
" plt.figure(figsize=(size,size))\n",
" plt.imshow(img, cmap=cmap)\n",
" plt.axis('off')\n",
" plt.show()"
]
},
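{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (not part of the original notebook): image2tensor\n",
"# and tensor2image should be inverses up to float rounding, so a random\n",
"# uint8 HWC image must survive the round trip through the normalized\n",
"# NCHW representation the generator consumes.\n",
"_img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)\n",
"_restored = tensor2image(image2tensor(_img))\n",
"assert _restored.shape == (64, 64, 3)\n",
"assert np.allclose(_restored, _img/255., atol=1e-5)"
]
},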
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAR4AAAiMCAYAAACZ/4hNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz92a8lSZ7fiX1+trj72e4WS+7VVd3VXb2xh2zu4lAvA3CAmYEgSHoYQRD0oP9KkAC9jYSBHvQkYIQBxYE01PRwmmyyWb1VdVVWZmRkLPfGXc7ii5n99GDmfvzeiIzMZgvUg8oDJ8459/hq9rOvfX+riaryy+2X2y+3X27/Pjfz/+sb+OX2y+2X2///bb8Enl9uv9x+uf17334JPL/cfrn9cvv3vv0SeH65/XL75fbvffsl8Pxy++X2y+3f++be9+NPurddXu/2gSmqIALjESIA8s6jFUgI1QDLm8j6657Fz7dwPSDRQNJ8vMwOU2HugRM5fpfxzKr3/j7eiI47Jc03qOXU4/HKg2Py+WTa77jv21uadlJN5XYNBgEDwSXiR0u2319z+0FFWHkGb0gynl/Hx5vu892tPm+M48exmRRQ5K0m/zavpYjMzqVj87zddQ9ux6hS9xH5Bom4d+vTR733fX6JfF0hqiIBmtuexU+/oLl6hk8Dxi4RW5V+HA/Q6Rp674zHM+t0weOnqdGE3AnM3tGxR+4/hOqDz8drT286f5ej3KiWzs4dJkbACNFZ+tPH7D76jNvNgq0xKIInYQFrBFMGlTGAyHTb44fjMHu77+WtD99xe8/+Mvv0zT2f5fV3HlffeKb3As/9M+eTPXyY+eCfBvR0rE47qxwH+XsbQkDkeNz7tmnQMAMfkXtHKYKYsodRSGkSwWlPAUnlWaZzPmjWOapOnZze8Vu5L50dh+S/jS+OoCb3OlC/Ad/k7c9T+35zG32XUAlVRcaW0DmIv/fU0/2+bxvHwzQZjXIwb+YyQEWO+zgLdW1xC494j0QQ48COBF2npn/HFR/c4vhEwn0Bnd3/Q9CZA8skLYqUvx+nT2U+Q46Xyu/5N50mLQVTRMXk+1CBRCLESIqaJcFwrz2U+2rJ+1r84dh6uK88+Ns3du97+l7f8ek7nPGt7f3A876TFumchunDm/2me7jHYh7+XaeBMB9f8vBk0zQ/Cs6R5czmNQRzD5zywAeZAEOPbGpEn9kAmK6vMtt/Pmpms4zkaymJ2Uk4Cnu+aZ0o1LcNW46N+k4s/IajvxNgjLcnb31XzWzzwbzxzstMt8gcSN5xnXlfaqZ6ghzBprwnABUsipWESEKMQsqAIPeE5p0C9Pb3aTQe5UpHpiPy1qHjnkcQOj6tvvMZFTQdfyvn1AmFxm6Uo9yWe1DJ8q5JkZSK/GQuMeHVODGNoqDzZ5o1hXx71z+UmPdO799Vju7tp8dG+pZj3w887zpY37XDjG/KuzpoBgfT7KeT4B+H8djkpoxveUso8kAeZ5CxkzmCylzVKp2k92awNIGRlnNoGrkHoGnCCx2lxCh5WKRylAGEJGBEEE35GoUqKwmbTBYso6gkxoEy8i0px89v9Z62Nx/Q78CYhwN+wuLZ7vI+NNDjMQ+3Ua6/ERhHIRd5N+DcQ613HT4DgvHey4lGLSoFRfuAhoAmzX2gpR++8dTzUTAJ2v3PUoZyvmiWj7HlRpxhpm7N5pjj3Rd5npiuFJaus94t550h9NQ3ZYxElKSJFBMaIypCUkMkoSMIST6nMYAWVUtncjEf9PPrPZj7Zo3/zlZ750/fBXzkwZfvOOn9FRjPOy40PutfBR3fOt1E9CfgeCdHLB9Gwbg/qNJsx+PQS6Y0hMb8q+bjU5ltVaTQ2JjP+WDAZKYjIAGMTmJV7gCVhFrQFADBmHz9pKAp5Znd5pnbaIDkZ2PyKNRajhn/IPee/35jvD1r6XwoTG30Hfv//rlE39VF79yEDLpv3dl0cbn3rHNQHUELxvliPrtk1Ti34TjNy33adESI+9ef1CruMcsjyIzXGIFnPilpUQNn7+Ox5Xcx8yeVSdx0POds/xGGMpsu19EZpAkkUSIBTYEUI8lYoigmSwSqBgzYsVFFi7xqFvkCfA+ZUr4vGd/u9dn7QOidcvMtoHUcc3+17a8OPA+vyQOW8T6J/6Z7nCapo/Ulfx2hvRw4yc43XeR+00mZ0RAp+rkcZw2Z3dCou8ts0E4YpogYMEXwVMlUJZ/DmFgYEaDhiCcCGArjKYZnTZAS5UTl0kcWNErJvJkm5nd8qCy89//07sn5O246e5/OO2tK1ftAn9Xh+2N5mt3fmnQLNM4Gae5RnRjTeA5b7HMVgrdgDIgVJB5VZh0nJhltLrPnkAI+84aY3+g4EY33KCMgjgCUCkgU1akcKyKIyfapGbQdAQWKasXsl3zsHMyO7Db3eRRhQAkxoEMPxoJlclJETWXySxnAVLAmt50pIDgBUemDIxPS8e0I2g+3dyDNN7Kf+Y/Md/qrgw78dYFnfg9yZCPfuL0HdHL/j7PWcbDdO+c7WmUybhfwm6ArU5rZBbL9JtsMTKHv2eaSTJ45jByvJuPUoZCsgtV8bFQIgomgdkBdRGwihgTJZCeXMagq0SiDE6I4BmIGpmRA0sTstKhpiiApXzOVAXlUp8aBfzRF3xOMe3anv9o28cPZ9cbv33ZGa2S0hU63MQexORLNiYrMfspjW6b5xYngZcQIPYJ4Rvms1hrNbSX3r3HvfqWA4QPwOQLKyGBjAZpibyGDjhhBTLEGG/JLUlabJX+GiBQbDTGLmKqQOUvmLeM18xx2tFRFIEoiaELjAIPNBvREuQeZQNYV3TeRwceMzFTm/XTfTXG/NWYII2//+q7tneznXTvdu5x+w49vb+8Fnsw48lmnm3jP+YT71O673MLR6j8Hr286gc4OOl7z6PdOmZ2gxTsQmbe0moQSSRYwSvKKVEqyASSRNGYVCbJBU0dXvGZBTAkGg+kMxlhi1aJVj5rIIAGJDqOGVO4vWggWVCI9hhg02y5wqJgikamwKBi5vI7COmsXmXXsUU0ZjcFvexv/qtsRNPKUef/acu/7tJkjQ3t46fm5RsAZPcujCjLuJwUYjMnqmxUQbzOLKCEQMjJIBBljEe7JxYMHn7Oxse0KmIFmuwop21hSys3vDLapMN5kl/dIH6KSUrHjjQitisaExh5NCTEWk33eebIi2wKTjECU2zQVGc/qecKGPa69AelRItEtj0BpLSQlomAly1ISDFpc8iUsYwRKkYnUMd7+vB8munZsnHtq/Tewnwdnedeuxy54L2U6bu8HnhlvHk1w77rLeQdnRM93UUT4yEKmG8x2nXH2yaqDTrQ2o76+w95TaOe9pz3e32RsHs8rSjYWS7bF+EjyAaqAaQS7SFAHjA+IU1LsCGGY3JuQB1walDRETDRwsJAMzlaEqkXqAfWJEBPag1V7tDg5TxAhWKUPQncIJDsgVU0SS7LFGK3FxiMGxGQj4wPBMDOdRIpUjaAwtc49N35plwdq0lt9PB+/IpPsyGyHdx2vwEGPr
l4RsOM9zPtmpvrcY0Hj/3LcZXxuIbOpBBBLrJBGUJONuNP59HidabTNqNS0z3G8IQnVRJREjBE1il15qkWN8RaxUg4fb6rEqKkSw0DoO9r9ntdvXnM47DEGKueoK0dTeZqmwlqX2UuGiPttrUJCSSZlcNIeGbYYOlzqsbWANagxgCeKRYqqFVLW8g2KsQY15ftD4CmyMmLnyCZFi2oreVJ9iyjcb7p39vm7Pn8XUvJwey/w3KR880bLzIfO1ORj/IeUVj3Ou3mn0ZeT/zIXhtmeU4BV3teouffkMlcj5OgLU02TCMvkgpeRQkGZPZVEcoHUKLIBswIqUN8TfQd+IEmHpoEYW3rpGEyi04Htdsd+tyMMAYvBqcfHhkoW1NUCV0eW5w2mFjQMpB5iD5oSCUPCMaiQTCB2ieRaVA9oswRfkawD8VkgDcfZcTYAx5ZKc6AXg5hs9zDlPb/s2CDMDa3ftCn3Be/Yj0d4mMjCRFJy+0aFP9kdyacTqAz4PG5wgC1gZCUDlCX3iSnfp4ExuyE
"text/plain": [
"<Figure size 2880x2880 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = 'cpu'\n",
"torch.set_grad_enabled(False)\n",
"image_size = 300 # Can be tuned, works best when the face width is between 200~250 px\n",
"\n",
"model = Generator().eval().to(device)\n",
"\n",
"ckpt = torch.load(f\"checkpoint/generator_celeba_distill.pt\", map_location=device)\n",
"model.load_state_dict(ckpt)\n",
" \n",
"results = []\n",
"for j in range(1,17):\n",
" image = load_image(f\"samples/faces/{j}.jpg\", image_size)\n",
" output = model(image.to(device))\n",
"\n",
" results.append(torch.cat([image, output.cpu()], 3))\n",
"results = torch.cat(results, 2)\n",
"\n",
"imshow(tensor2image(results),40)\n",
"cv2.imwrite('./samples/face_results.jpg', cv2.cvtColor(255*tensor2image(results), cv2.COLOR_BGR2RGB))"
]
}
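,
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged single-image example (not part of the original notebook): the same\n",
"# pipeline applied to one face, assuming samples/faces/1.jpg exists as in\n",
"# the batch loop above; the output filename is illustrative.\n",
"single = load_image(\"samples/faces/1.jpg\", image_size)\n",
"stylized = model(single.to(device)).cpu()\n",
"imshow(tensor2image(stylized), 8)\n",
"cv2.imwrite('./samples/face_1_anime.jpg',\n",
"            cv2.cvtColor((255*tensor2image(stylized)).astype(np.uint8), cv2.COLOR_RGB2BGR))"
]
}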
],
"metadata": {
"kernelspec": {
"display_name": "torch171",
"language": "python",
"name": "torch171"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}