From 3a8694b642413417a8f3766f4fc051f3c3a585b4 Mon Sep 17 00:00:00 2001 From: Yudi Date: Tue, 27 Aug 2024 14:12:36 +0800 Subject: [PATCH 1/2] add qwen prompt refiner --- diffsynth/configs/model_config.py | 11 +++ diffsynth/prompters/prompt_refiners.py | 69 ++++++++++++++++--- .../image_synthesis/qwen_prompt_refining.py | 30 ++++++++ 3 files changed, 102 insertions(+), 8 deletions(-) create mode 100644 examples/image_synthesis/qwen_prompt_refining.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 5216aaf..9c079a6 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -209,6 +209,17 @@ preset_models_on_modelscope = { "RIFE": [ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"), ], + # Qwen Prompt + "QwenPrompt": [ + ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ], # Beautiful Prompt "BeautifulPrompt": [ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), diff --git a/diffsynth/prompters/prompt_refiners.py b/diffsynth/prompters/prompt_refiners.py index 6d8f0df..dc9fd16 100644 --- a/diffsynth/prompters/prompt_refiners.py +++ b/diffsynth/prompters/prompt_refiners.py @@ -3,18 +3,19 @@ from ..models.model_manager import ModelManager import torch - -class BeautifulPrompt(torch.nn.Module): +class QwenPrompt(torch.nn.Modile): + # This class leverages the open-source Qwen model to translate Chinese prompts into English, + # with an integrated optimization mechanism for enhanced translation quality. def __init__(self, tokenizer_path=None, model=None, template=""): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.model = model self.template = template - @staticmethod def from_model_manager(model_nameger: ModelManager): - model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True) + model, model_path = model_nameger.fetch_model("qwen_prompt", require_model_path=True) + template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:' if model_path.endswith("v2"): template = """Converts a simple image description into a prompt. \ @@ -22,13 +23,12 @@ Prompts are formatted as multiple related tags separated by commas, plus you can or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \ but make sure there is a correlation between the input and output.\n\ ### Input: {raw_prompt}\n### Output:""" - beautiful_prompt = BeautifulPrompt( + beautiful_prompt = QwenPrompt( tokenizer_path=model_path, model=model, template=template ) - return beautiful_prompt - + return qwen_prompt def __call__(self, raw_prompt, positive=True, **kwargs): if positive: @@ -48,7 +48,60 @@ but make sure there is a correlation between the input and output.\n\ outputs[:, input_ids.size(1):], skip_special_tokens=True )[0].strip() - print(f"Your prompt is refined by BeautifulPrompt: {prompt}") + print(f"Your prompt is refined by Qwen : {prompt}") + return prompt + else: + return raw_prompt + + +class BeautifulPrompt(torch.nn.Module): + def __init__(self, tokenizer_path=None, model=None, template=""): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self.model = model + self.template = template + + + @staticmethod + def from_model_manager(model_nameger: ModelManager): + model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True) + system_prompt = """你是一个英文图片描述家,你看到一段中文图片描述后,尽可能用精简准确的英文,将中文的图片描述的意境用英文短句展示出来,并附带图片风格描述,如果中文描述中没有明确的风格,你需要根据中文意境额外添加一些风格描述,确保图片中的内容丰富生动。\n\n你有如下几种不同的风格描述示例进行参考:\n\n特写风格: Extreme close-up by Oliver Dum, magnified view of a dewdrop on a spider web occupying the frame, the camera focuses closely on the object with the background blurred. The image is lit with natural sunlight, enhancing the vivid textures and contrasting colors.\n\n复古风格: Photograph of women working, Daguerreotype, calotype, tintype, collodion, ambrotype, carte-de-visite, gelatin silver, dry plate, wet plate, stereoscope, albumen print, cyanotype, glass, lantern slide, camera \n\n动漫风格: a happy dairy cow just finished grazing, in the style of cartoon realism, disney animation, hyper-realistic portraits, 32k uhd, cute cartoonish designs, wallpaper, luminous brushwork \n\n普通人物场景风格: A candid shot of young best friends dirty, at the skatepark, natural afternoon light, Canon EOS R5, 100mm, F 1.2 aperture setting capturing a moment, cinematic \n\n景观风格: bright beautiful sunrise over the sea and rocky mountains, photorealistic, \n\n设计风格: lionface circle tshirt design, in the style of detailed botanical illustrations, colorful cartoon, exotic atmosphere, 2d game art, white background, contour \n\n动漫风格: Futuristic mecha robot walking through a neon cityscape, with lens flares, dramatic lighting, illustrated like a Gundam anime poster \n\n都市风格: warmly lit room with large monitors on the clean desk, overlooking the city, ultrareal and photorealistic, \n\n\n请根据上述图片风格,以及中文描述生成对应的英文图片描述 \n\n 请注意:\n\n 如果中文为成语或古诗,不能只根据表层含义来进行描述,而要描述其中的意境!例如:“胸有成竹”的图片场景中并没有竹子,而是描述一个人非常自信的场景,请在英文翻译中不要提到bamboo,以此类推\n\n字数不超过100字""" + messages = [{ + 'role': 'system', + 'content': system_prompt + }, { + 'role': 'user', + 'content': "{raw_prompt}" + }] + + qwen_prompt = QwenPrompt( + tokenizer_path=model_path, + model=model, + template=template + ) + return qwen_prompt + + + def __call__(self, raw_prompt, positive=True, **kwargs): + if positive: + model_input = self.template.format(raw_prompt=raw_prompt) + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + model_inputs = self.tokenizer([text], return_tensors="pt").to(device) + + generated_ids = self.model.generate( + model_inputs.input_ids, + max_new_tokens=512 + ) + generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + + prompt = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + print(f"Your prompt is refined by Qwen: {prompt}, RawPrompt: {raw_prompt}") return prompt else: return raw_prompt diff --git a/examples/image_synthesis/qwen_prompt_refining.py b/examples/image_synthesis/qwen_prompt_refining.py new file mode 100644 index 0000000..963e0b5 --- /dev/null +++ b/examples/image_synthesis/qwen_prompt_refining.py @@ -0,0 +1,30 @@ +from diffsynth import ModelManager, SDXLImagePipeline, download_models, QwenPrompt +import torch + + +# Download models (automatically) +# `models/stable_diffusion_xl/sd_xl_base_1.0.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors) +# `models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd/`: [link](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) +# `models/translator/opus-mt-zh-en/`: [link](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) +download_models(["StableDiffusionXL_v1", "QwenPrompt", "opus-mt-zh-en"]) + +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") +model_manager.load_models([ + "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors", + "models/QwenPrompt/qwen2-1.5b-instruct", + "models/translator/opus-mt-zh-en" +]) +pipe = SDXLImagePipeline.from_model_manager(model_manager, prompt_refiner_classes=[QwenPrompt]) + +prompt = "一个漂亮的女孩" +negative_prompt = "" + +for seed in range(4): + torch.manual_seed(seed) + image = pipe( + prompt=prompt, negative_prompt=negative_prompt, + height=1024, width=1024, + num_inference_steps=30 + ) + image.save(f"{seed}.jpg") From e5e55345dcaa61e71290947476b162a1ab2fadf2 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 4 Sep 2024 17:12:01 +0800 Subject: [PATCH 2/2] support qwen prompt refiner --- diffsynth/configs/model_config.py | 2 + diffsynth/prompters/__init__.py | 2 +- diffsynth/prompters/prompt_refiners.py | 57 ++++++++++--------- .../image_synthesis/qwen_prompt_refining.py | 7 +-- 4 files changed, 33 insertions(+), 35 deletions(-) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 9c079a6..3e492d5 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -78,6 +78,7 @@ huggingface_model_loader_configs = [ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None), ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None), ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None), + ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None), ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"), ] patch_model_loader_configs = [ @@ -307,4 +308,5 @@ Preset_model_id: TypeAlias = Literal[ "ControlNet_union_sdxl_promax", "FLUX.1-dev", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", + "QwenPrompt", ] \ No newline at end of file diff --git a/diffsynth/prompters/__init__.py b/diffsynth/prompters/__init__.py index 530eece..4c6a20a 100644 --- a/diffsynth/prompters/__init__.py +++ b/diffsynth/prompters/__init__.py @@ -1,4 +1,4 @@ -from .prompt_refiners import Translator, BeautifulPrompt +from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt from .sd_prompter import SDPrompter from .sdxl_prompter import SDXLPrompter from .sd3_prompter import SD3Prompter diff --git a/diffsynth/prompters/prompt_refiners.py b/diffsynth/prompters/prompt_refiners.py index dc9fd16..4ba469a 100644 --- a/diffsynth/prompters/prompt_refiners.py +++ b/diffsynth/prompters/prompt_refiners.py @@ -3,19 +3,18 @@ from ..models.model_manager import ModelManager import torch -class QwenPrompt(torch.nn.Modile): - # This class leverages the open-source Qwen model to translate Chinese prompts into English, - # with an integrated optimization mechanism for enhanced translation quality. + +class BeautifulPrompt(torch.nn.Module): def __init__(self, tokenizer_path=None, model=None, template=""): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.model = model self.template = template - @staticmethod - def from_model_manager(model_nameger: ModelManager): - model, model_path = model_nameger.fetch_model("qwen_prompt", require_model_path=True) + @staticmethod + def from_model_manager(model_manager: ModelManager): + model, model_path = model_manager.fetch_model("beautiful_prompt", require_model_path=True) template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:' if model_path.endswith("v2"): template = """Converts a simple image description into a prompt. \ @@ -23,12 +22,13 @@ Prompts are formatted as multiple related tags separated by commas, plus you can or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \ but make sure there is a correlation between the input and output.\n\ ### Input: {raw_prompt}\n### Output:""" - beautiful_prompt = QwenPrompt( + beautiful_prompt = BeautifulPrompt( tokenizer_path=model_path, model=model, template=template ) - return qwen_prompt + return beautiful_prompt + def __call__(self, raw_prompt, positive=True, **kwargs): if positive: @@ -48,49 +48,50 @@ but make sure there is a correlation between the input and output.\n\ outputs[:, input_ids.size(1):], skip_special_tokens=True )[0].strip() - print(f"Your prompt is refined by Qwen : {prompt}") + print(f"Your prompt is refined by BeautifulPrompt: {prompt}") return prompt else: return raw_prompt -class BeautifulPrompt(torch.nn.Module): - def __init__(self, tokenizer_path=None, model=None, template=""): + +class QwenPrompt(torch.nn.Module): + # This class leverages the open-source Qwen model to translate Chinese prompts into English, + # with an integrated optimization mechanism for enhanced translation quality. + def __init__(self, tokenizer_path=None, model=None, system_prompt=""): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.model = model - self.template = template + self.system_prompt = system_prompt @staticmethod def from_model_manager(model_nameger: ModelManager): - model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True) + model, model_path = model_nameger.fetch_model("qwen_prompt", require_model_path=True) system_prompt = """你是一个英文图片描述家,你看到一段中文图片描述后,尽可能用精简准确的英文,将中文的图片描述的意境用英文短句展示出来,并附带图片风格描述,如果中文描述中没有明确的风格,你需要根据中文意境额外添加一些风格描述,确保图片中的内容丰富生动。\n\n你有如下几种不同的风格描述示例进行参考:\n\n特写风格: Extreme close-up by Oliver Dum, magnified view of a dewdrop on a spider web occupying the frame, the camera focuses closely on the object with the background blurred. The image is lit with natural sunlight, enhancing the vivid textures and contrasting colors.\n\n复古风格: Photograph of women working, Daguerreotype, calotype, tintype, collodion, ambrotype, carte-de-visite, gelatin silver, dry plate, wet plate, stereoscope, albumen print, cyanotype, glass, lantern slide, camera \n\n动漫风格: a happy dairy cow just finished grazing, in the style of cartoon realism, disney animation, hyper-realistic portraits, 32k uhd, cute cartoonish designs, wallpaper, luminous brushwork \n\n普通人物场景风格: A candid shot of young best friends dirty, at the skatepark, natural afternoon light, Canon EOS R5, 100mm, F 1.2 aperture setting capturing a moment, cinematic \n\n景观风格: bright beautiful sunrise over the sea and rocky mountains, photorealistic, \n\n设计风格: lionface circle tshirt design, in the style of detailed botanical illustrations, colorful cartoon, exotic atmosphere, 2d game art, white background, contour \n\n动漫风格: Futuristic mecha robot walking through a neon cityscape, with lens flares, dramatic lighting, illustrated like a Gundam anime poster \n\n都市风格: warmly lit room with large monitors on the clean desk, overlooking the city, ultrareal and photorealistic, \n\n\n请根据上述图片风格,以及中文描述生成对应的英文图片描述 \n\n 请注意:\n\n 如果中文为成语或古诗,不能只根据表层含义来进行描述,而要描述其中的意境!例如:“胸有成竹”的图片场景中并没有竹子,而是描述一个人非常自信的场景,请在英文翻译中不要提到bamboo,以此类推\n\n字数不超过100字""" - messages = [{ - 'role': 'system', - 'content': system_prompt - }, { - 'role': 'user', - 'content': "{raw_prompt}" - }] - qwen_prompt = QwenPrompt( tokenizer_path=model_path, model=model, - template=template + system_prompt=system_prompt ) return qwen_prompt - + def __call__(self, raw_prompt, positive=True, **kwargs): if positive: - model_input = self.template.format(raw_prompt=raw_prompt) + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': raw_prompt + }] text = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - model_inputs = self.tokenizer([text], return_tensors="pt").to(device) + model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) generated_ids = self.model.generate( model_inputs.input_ids, @@ -100,12 +101,12 @@ class BeautifulPrompt(torch.nn.Module): output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] - prompt = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] - print(f"Your prompt is refined by Qwen: {prompt}, RawPrompt: {raw_prompt}") + prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + print(f"Your prompt is refined by Qwen: {prompt}") return prompt else: return raw_prompt - + class Translator(torch.nn.Module): diff --git a/examples/image_synthesis/qwen_prompt_refining.py b/examples/image_synthesis/qwen_prompt_refining.py index 963e0b5..1cf1239 100644 --- a/examples/image_synthesis/qwen_prompt_refining.py +++ b/examples/image_synthesis/qwen_prompt_refining.py @@ -2,18 +2,13 @@ from diffsynth import ModelManager, SDXLImagePipeline, download_models, QwenProm import torch -# Download models (automatically) -# `models/stable_diffusion_xl/sd_xl_base_1.0.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors) -# `models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd/`: [link](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) -# `models/translator/opus-mt-zh-en/`: [link](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) -download_models(["StableDiffusionXL_v1", "QwenPrompt", "opus-mt-zh-en"]) +download_models(["StableDiffusionXL_v1", "QwenPrompt"]) # Load models model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") model_manager.load_models([ "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct", - "models/translator/opus-mt-zh-en" ]) pipe = SDXLImagePipeline.from_model_manager(model_manager, prompt_refiner_classes=[QwenPrompt])