{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Generating images with the MX-Font model from a reference style\n", "In this example, we generate images with a trained MX-Font model from a single reference style.\n", "To generate multiple styles, use `eval.py` instead of this notebook, because loading the reference styles is much simpler there." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Loading packages\n", "* First, load the packages used in this code.\n", "* All of the packages are available via `pip`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import json\n", "from pathlib import Path\n", "from PIL import Image\n", "\n", "import torch\n", "from sconf import Config\n", "from torchvision import transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* These modules are defined in this repository." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import models\n", "from datasets import read_font, render\n", "from utils import save_tensor_to_image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Build model\n", "* Build and load the trained model.\n", "* `weight_path`:\n", "    * The location of the trained model weights."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "########################################################\n", "weight_path = \"generator.pth\"  # Path to the trained generator weights\n", "########################################################\n", "\n", "cfg = Config(\"cfgs/eval.yaml\", default=\"cfgs/defaults.yaml\")\n", "transform = transforms.Compose(\n", "    [transforms.Resize((128, 128)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]\n", ")\n", "with open(\"data/chn_decomposition.json\") as f:\n", "    decomposition = json.load(f)\n", "\n", "g_kwargs = cfg.get('g_args', {})\n", "gen = models.Generator(1, cfg.C, 1, **g_kwargs).cuda().eval()\n", "weight = torch.load(weight_path)\n", "# Use the EMA generator weights when the checkpoint contains them.\n", "if \"generator_ema\" in weight:\n", "    weight = weight[\"generator_ema\"]\n", "gen.load_state_dict(weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Load reference images.\n", "* `ref_path`:\n", "    * Path to the reference font or the reference images.\n", "    * If you are using a ttf file, set this to the location of the ttf file.\n", "    * If you are using rendered images, set this to the directory that contains the reference images.\n", "* `extension`:\n", "    * If you are using image files, set this to their extension (png, jpg, etc.).\n", "    * This will be ignored if `use_ttf` is True.\n", "* `batch_size`:\n", "    * The number of reference images processed at once."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "########################################################\n", "ref_path = \"data/images/test/ZCOOLKuaiLe-Regular\"  # Path to the reference images\n", "extension = \"png\"  # Extension of the reference images\n", "batch_size = 3  # The batch size\n", "########################################################\n", "\n", "ref_paths = Path(ref_path).glob(f\"*.{extension}\")\n", "ref_imgs = torch.stack([transform(Image.open(str(p))) for p in ref_paths]).cuda()\n", "ref_batches = torch.split(ref_imgs, batch_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Extract style factors from reference images." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "style_facts = {}\n", "\n", "# Encode each reference batch and factorize it with index 0 to get its style factors.\n", "for batch in ref_batches:\n", "    style_fact = gen.factorize(gen.encode(batch), 0)\n", "    for k in style_fact:\n", "        style_facts.setdefault(k, []).append(style_fact[k])\n", "\n", "# Average the style factors over all reference images.\n", "style_facts = {k: torch.cat(v).mean(0, keepdim=True) for k, v in style_facts.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Generate the images.\n", "* `gen_chars`: The characters to generate.\n", "* `save_dir`: Directory where the generated images are saved.\n", "* `source_path`: Path to the font file used as the source font."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "########################################################\n", "gen_chars = \"意味岸\"  # Characters to generate\n", "save_dir = Path(\"./results\")  # Directory where you want to save generated images\n", "source_path = \"data/chn_source.ttf\"  # Path to the font file to render the source images\n", "########################################################\n", "\n", "save_dir.mkdir(parents=True, exist_ok=True)\n", "\n", "source_font = read_font(source_path)\n", "for char in gen_chars:\n", "    # Render the character in the source font and factorize it with index 1 to get its character factors.\n", "    source_img = transform(render(source_font, char)).unsqueeze(0).cuda()\n", "    char_facts = gen.factorize(gen.encode(source_img), 1)\n", "\n", "    # Combine the style and character factors, then decode them into an image.\n", "    gen_feats = gen.defactorize([style_facts, char_facts])\n", "    out = gen.decode(gen_feats).detach().cpu()[0]\n", "\n", "    path = save_dir / f\"{char}.png\"\n", "    save_tensor_to_image(out, path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }