mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-24 10:18:12 +00:00)
Commit: fix bugs
@@ -78,8 +78,7 @@ huggingface_model_loader_configs = [
     ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
     ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
     ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
-    # ("AutoModelForCausalLM", "transformers","omost_prompt",None),
-    ("LlamaForCausalLM", "transformers.models.llama.modeling_llama","omost_prompt",None),
+    ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
     ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
 ]
 patch_model_loader_configs = [
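
Note: each tuple in huggingface_model_loader_configs appears to mean (class name, module that defines it, model name used by fetch_model, optional wrapper class). A minimal sketch of how such an entry could be resolved, assuming those semantics; resolve_loader_config is a hypothetical helper, not DiffSynth-Studio's actual loader:

    import importlib

    def resolve_loader_config(config):
        class_name, module_path, model_name, wrapper_class_name = config
        module = importlib.import_module(module_path)  # e.g. transformers.models.llama.modeling_llama
        model_class = getattr(module, class_name)      # e.g. LlamaForCausalLM
        return model_name, model_class, wrapper_class_name

    name, cls, wrapper = resolve_loader_config(
        ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None))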
@@ -222,15 +221,14 @@ preset_models_on_modelscope = {
     ],
     # Omost prompt
     "OmostPrompt":[
-        ("Omost/omost-llama-3-8b-4bits","model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
-        ("Omost/omost-llama-3-8b-4bits","model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
-        ("Omost/omost-llama-3-8b-4bits","tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
-        ("Omost/omost-llama-3-8b-4bits","tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
-        ("Omost/omost-llama-3-8b-4bits","config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
-        ("Omost/omost-llama-3-8b-4bits","configuration.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
-        ("Omost/omost-llama-3-8b-4bits","generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
-        ("Omost/omost-llama-3-8b-4bits","model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
-        ("Omost/omost-llama-3-8b-4bits","special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+        ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+        ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+        ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+        ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+        ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+        ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+        ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+        ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
     ],
 
     # Translator
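
Note: each preset entry appears to follow the pattern (ModelScope model id, file path inside the repo, local download directory). A small sketch of that mapping; plan_download is a hypothetical helper, not the project's downloader:

    import os

    def plan_download(entry):
        model_id, origin_file_path, local_dir = entry
        local_path = os.path.join(local_dir, os.path.basename(origin_file_path))
        return model_id, origin_file_path, local_path

    print(plan_download(("Omost/omost-llama-3-8b-4bits", "config.json",
                         "models/OmostPrompt/omost-llama-3-8b-4bits")))
    # -> ('Omost/omost-llama-3-8b-4bits', 'config.json',
    #     'models/OmostPrompt/omost-llama-3-8b-4bits/config.json')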
@@ -311,5 +309,5 @@ Preset_model_id: TypeAlias = Literal[
     "ControlNet_union_sdxl_promax",
     "FLUX.1-dev",
     "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
-    "OmostPrompt"
+    "OmostPrompt",
 ]
@@ -50,4 +50,13 @@ class BasePipeline(torch.nn.Module):
         noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
         noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
         return noise_pred
 
+
+    def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
+        extended_prompt_dict = self.prompter.extend_prompt(prompt)
+        prompt = extended_prompt_dict.get("prompt", prompt)
+        local_prompts += extended_prompt_dict.get("prompts", [])
+        masks += extended_prompt_dict.get("masks", [])
+        mask_scales += [5.0] * len(extended_prompt_dict.get("masks", []))
+        return prompt, local_prompts, masks, mask_scales
+
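
Note: the new BasePipeline.extend_prompt relies on prompter.extend_prompt returning a dict with optional keys "prompt" (rewritten global prompt), "prompts" (regional prompts), and "masks", and it assigns every new mask a default scale of 5.0. A self-contained sketch of that contract; StubPrompter is hypothetical and its strings stand in for real mask tensors:

    class StubPrompter:
        def extend_prompt(self, prompt):
            return {"prompt": prompt + ", highly detailed",
                    "prompts": ["a witch", "ice and fire magic"],
                    "masks": ["mask_a", "mask_b"]}

    d = StubPrompter().extend_prompt("a witch")
    prompt = d.get("prompt", "a witch")
    local_prompts = [] + d.get("prompts", [])
    masks = [] + d.get("masks", [])
    mask_scales = [5.0] * len(d.get("masks", []))
    print(prompt, local_prompts, mask_scales)  # two regions, each scaled 5.0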
@@ -33,7 +33,7 @@ class FluxImagePipeline(BasePipeline):
         self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
         self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
         self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
-        self.prompter.load_prompt_extenders(model_manager,prompt_extender_classes)
+        self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
 
 
     @staticmethod
@@ -107,13 +107,8 @@ class FluxImagePipeline(BasePipeline):
         latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
 
         # Extend prompt
-        if len(self.prompter.extenders) > 0:
-            extended_prompt_dict = self.prompter.extend_prompt(prompt)
-            prompt = extended_prompt_dict.get("prompt", prompt)
-            local_prompts += extended_prompt_dict.get("prompts", [])
-            masks += extended_prompt_dict.get("masks",[])
-            mask_scales += [5.0 for _ in range(len(extended_prompt_dict.get("masks",[])))]
+        prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
 
         # Encode prompts
         prompt_emb_posi = self.encode_prompt(prompt, positive=True)
         if cfg_scale != 1.0:
@@ -5,3 +5,4 @@ from .sd3_prompter import SD3Prompter
 from .hunyuan_dit_prompter import HunyuanDiTPrompter
 from .kolors_prompter import KolorsPrompter
 from .flux_prompter import FluxPrompter
+from .omost import OmostPromter
@@ -37,12 +37,12 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
 
 
 class BasePrompter:
-    def __init__(self, refiners=[],extenders = []):
+    def __init__(self, refiners=[], extenders=[]):
         self.refiners = refiners
         self.extenders = extenders
 
 
-    def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]): # manager
+    def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
         for refiner_class in refiner_classes:
             refiner = refiner_class.from_model_manager(model_manager)
             self.refiners.append(refiner)
@@ -63,7 +63,7 @@ class BasePrompter:
         return prompt
 
     @torch.no_grad()
-    def extend_prompt(self,prompt:str,positive = True):
+    def extend_prompt(self, prompt:str, positive=True):
         extended_prompt = dict(prompt=prompt)
         for extender in self.extenders:
             extended_prompt = extender(extended_prompt)
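
Note: the loop in BasePrompter.extend_prompt implies each extender is a callable that takes the running dict and returns an updated one, so extenders compose in registration order. A toy, hypothetical extender illustrating that protocol:

    class SuffixExtender:
        def __init__(self, suffix):
            self.suffix = suffix

        def __call__(self, extended_prompt):
            extended_prompt["prompt"] += self.suffix
            return extended_prompt

    extended_prompt = dict(prompt="a witch")
    for extender in [SuffixExtender(", ice magic"), SuffixExtender(", fire magic")]:
        extended_prompt = extender(extended_prompt)
    print(extended_prompt["prompt"])  # a witch, ice magic, fire magic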
@@ -1,6 +1,4 @@
-
-# from .prompt_refiners import BeautifulPrompt
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoTokenizer, TextIteratorStreamer
 import difflib
 import torch
 import numpy as np
@@ -225,10 +223,6 @@ class Canvas:
                 prefixes=component['prefixes'],
                 suffixes=component['suffixes']
             ))
 
-        import pickle
-        with open("tmp.pkl","wb+") as f:
-            pickle.dump(bag_of_conditions,f)
-
         return dict(
             initial_latent=initial_latent,
@@ -261,10 +255,6 @@ class OmostPromter(torch.nn.Module):
 
     @staticmethod
     def from_model_manager(model_manager: ModelManager):
-        # model, model_path = model_manager.fetch_model("omost", require_model_path=True)
-        # omost = OmostPromter(tokenizer_path=model_path, model=model)
-        # return omost
-        print(model_manager)
         model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
         tokenizer = AutoTokenizer.from_pretrained(model_path)
         omost = OmostPromter(
@@ -1,8 +1,7 @@
 
 import torch
-from diffsynth import download_models,FluxImagePipeline
-from diffsynth.models.model_manager import ModelManager
-from diffsynth.prompters.omost import OmostPromter
+from diffsynth import download_models, ModelManager, OmostPromter, FluxImagePipeline
 download_models(["OmostPrompt"])
 download_models(["FLUX.1-dev"])
 
@@ -15,10 +14,11 @@ model_manager.load_models([
     "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
 ])
 
-pipe = FluxImagePipeline.from_model_manager(model_manager,prompt_extender_classes=[OmostPromter])
+pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
 
-negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
-image = pipe("generate an image of a witch who is releasing ice and fire magic",
-             num_inference_steps=30, embedded_guidance=3.5,
-             negative_prompt=negative_prompt)
+torch.manual_seed(0)
+image = pipe(
+    prompt="an image of a witch who is releasing ice and fire magic",
+    num_inference_steps=30, embedded_guidance=3.5
+)
 image.save("image_omost.jpg")
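
Note: the updated example seeds the RNG with torch.manual_seed(0) so runs are reproducible, and it drops the negative prompt. If classifier-free guidance is wanted again, one hedged variation follows; the argument names are assumptions based on the pipeline code above, where cfg_scale != 1.0 enables the negative branch:

    torch.manual_seed(0)
    image = pipe(
        prompt="an image of a witch who is releasing ice and fire magic",
        negative_prompt="dark, worst quality, low quality, monochrome",
        cfg_scale=2.0,
        num_inference_steps=30, embedded_guidance=3.5,
    )
    image.save("image_omost_cfg.jpg")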