From d70cd04b15984dddbae123998b845e97beb63d20 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 4 Sep 2024 12:48:32 +0800 Subject: [PATCH] fix bugs --- diffsynth/configs/model_config.py | 22 +++++++++---------- diffsynth/pipelines/base.py | 9 ++++++++ diffsynth/pipelines/flux_image.py | 11 +++------- diffsynth/prompters/__init__.py | 1 + diffsynth/prompters/base_prompter.py | 6 ++--- diffsynth/prompters/omost.py | 12 +--------- .../omost_flux_text_to_image.py | 18 +++++++-------- 7 files changed, 36 insertions(+), 43 deletions(-) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index f738e8d..b5d3501 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -78,8 +78,7 @@ huggingface_model_loader_configs = [ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None), ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None), ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None), - # ("AutoModelForCausalLM", "transformers","omost_prompt",None), - ("LlamaForCausalLM", "transformers.models.llama.modeling_llama","omost_prompt",None), + ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None), ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"), ] patch_model_loader_configs = [ @@ -222,15 +221,14 @@ preset_models_on_modelscope = { ], # Omost prompt "OmostPrompt":[ - ("Omost/omost-llama-3-8b-4bits","model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits","model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits","tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits","tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits","config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits","configuration.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits","generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits","model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits","special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ], # Translator @@ -311,5 +309,5 @@ Preset_model_id: TypeAlias = Literal[ "ControlNet_union_sdxl_promax", "FLUX.1-dev", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", - "OmostPrompt" + "OmostPrompt", ] \ No newline at end of file diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index 78e66b5..2feb405 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -50,4 +50,13 @@ class BasePipeline(torch.nn.Module): noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales) return noise_pred + + + def extend_prompt(self, prompt, local_prompts, masks, mask_scales): + extended_prompt_dict = self.prompter.extend_prompt(prompt) + prompt = extended_prompt_dict.get("prompt", prompt) + local_prompts += extended_prompt_dict.get("prompts", []) + masks += extended_prompt_dict.get("masks", []) + mask_scales += [5.0] * len(extended_prompt_dict.get("masks", [])) + return prompt, local_prompts, masks, mask_scales \ No newline at end of file diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 57862d7..8d6a246 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -33,7 +33,7 @@ class FluxImagePipeline(BasePipeline): self.vae_encoder = model_manager.fetch_model("flux_vae_encoder") self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2) self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes) - self.prompter.load_prompt_extenders(model_manager,prompt_extender_classes) + self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes) @staticmethod @@ -107,13 +107,8 @@ class FluxImagePipeline(BasePipeline): latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) # Extend prompt - if len(self.prompter.extenders) > 0: - extended_prompt_dict = self.prompter.extend_prompt(prompt) - prompt = extended_prompt_dict.get("prompt", prompt) - local_prompts += extended_prompt_dict.get("prompts", []) - masks += extended_prompt_dict.get("masks",[]) - mask_scales += [5.0 for _ in range(len(extended_prompt_dict.get("masks",[])))] - + prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) + # Encode prompts prompt_emb_posi = self.encode_prompt(prompt, positive=True) if cfg_scale != 1.0: diff --git a/diffsynth/prompters/__init__.py b/diffsynth/prompters/__init__.py index 530eece..6bf909e 100644 --- a/diffsynth/prompters/__init__.py +++ b/diffsynth/prompters/__init__.py @@ -5,3 +5,4 @@ from .sd3_prompter import SD3Prompter from .hunyuan_dit_prompter import HunyuanDiTPrompter from .kolors_prompter import KolorsPrompter from .flux_prompter import FluxPrompter +from .omost import OmostPromter diff --git a/diffsynth/prompters/base_prompter.py b/diffsynth/prompters/base_prompter.py index 47215fd..9f0101a 100644 --- a/diffsynth/prompters/base_prompter.py +++ b/diffsynth/prompters/base_prompter.py @@ -37,12 +37,12 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None): class BasePrompter: - def __init__(self, refiners=[],extenders = []): + def __init__(self, refiners=[], extenders=[]): self.refiners = refiners self.extenders = extenders - def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]): # manager + def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]): for refiner_class in refiner_classes: refiner = refiner_class.from_model_manager(model_manager) self.refiners.append(refiner) @@ -63,7 +63,7 @@ class BasePrompter: return prompt @torch.no_grad() - def extend_prompt(self,prompt:str,positive = True): + def extend_prompt(self, prompt:str, positive=True): extended_prompt = dict(prompt=prompt) for extender in self.extenders: extended_prompt = extender(extended_prompt) diff --git a/diffsynth/prompters/omost.py b/diffsynth/prompters/omost.py index 1805e92..39999ce 100644 --- a/diffsynth/prompters/omost.py +++ b/diffsynth/prompters/omost.py @@ -1,6 +1,4 @@ - -# from .prompt_refiners import BeautifulPrompt -from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer +from transformers import AutoTokenizer, TextIteratorStreamer import difflib import torch import numpy as np @@ -225,10 +223,6 @@ class Canvas: prefixes=component['prefixes'], suffixes=component['suffixes'] )) - - import pickle - with open("tmp.pkl","wb+") as f: - pickle.dump(bag_of_conditions,f) return dict( initial_latent=initial_latent, @@ -261,10 +255,6 @@ class OmostPromter(torch.nn.Module): @staticmethod def from_model_manager(model_manager: ModelManager): - # model, model_path = model_manager.fetch_model("omost", require_model_path=True) - # omost = OmostPromter(tokenizer_path=model_path, model=model) - # return omost - print(model_manager) model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True) tokenizer = AutoTokenizer.from_pretrained(model_path) omost = OmostPromter( diff --git a/examples/image_synthesis/omost_flux_text_to_image.py b/examples/image_synthesis/omost_flux_text_to_image.py index 57450e6..7562342 100644 --- a/examples/image_synthesis/omost_flux_text_to_image.py +++ b/examples/image_synthesis/omost_flux_text_to_image.py @@ -1,8 +1,7 @@ - import torch -from diffsynth import download_models,FluxImagePipeline -from diffsynth.models.model_manager import ModelManager -from diffsynth.prompters.omost import OmostPromter +from diffsynth import download_models, ModelManager, OmostPromter, FluxImagePipeline + + download_models(["OmostPrompt"]) download_models(["FLUX.1-dev"]) @@ -15,10 +14,11 @@ model_manager.load_models([ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" ]) -pipe = FluxImagePipeline.from_model_manager(model_manager,prompt_extender_classes=[OmostPromter]) +pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter]) -negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," -image = pipe("generate an image of a witch who is releasing ice and fire magic", - num_inference_steps=30, embedded_guidance=3.5, - negative_prompt=negative_prompt) +torch.manual_seed(0) +image = pipe( + prompt="an image of a witch who is releasing ice and fire magic", + num_inference_steps=30, embedded_guidance=3.5 +) image.save("image_omost.jpg")