This commit is contained in:
Artiprocher
2024-09-04 12:48:32 +08:00
parent 0b066d3cb4
commit d70cd04b15
7 changed files with 36 additions and 43 deletions

View File

@@ -78,7 +78,6 @@ huggingface_model_loader_configs = [
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None), ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None), ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None), ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
# ("AutoModelForCausalLM", "transformers","omost_prompt",None),
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None), ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"), ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
] ]
@@ -227,7 +226,6 @@ preset_models_on_modelscope = {
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits","configuration.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
@@ -311,5 +309,5 @@ Preset_model_id: TypeAlias = Literal[
"ControlNet_union_sdxl_promax", "ControlNet_union_sdxl_promax",
"FLUX.1-dev", "FLUX.1-dev",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"OmostPrompt" "OmostPrompt",
] ]

View File

@@ -51,3 +51,12 @@ class BasePipeline(torch.nn.Module):
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales) noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred return noise_pred
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [5.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales

View File

@@ -107,12 +107,7 @@ class FluxImagePipeline(BasePipeline):
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Extend prompt # Extend prompt
if len(self.prompter.extenders) > 0: prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks",[])
mask_scales += [5.0 for _ in range(len(extended_prompt_dict.get("masks",[])))]
# Encode prompts # Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, positive=True) prompt_emb_posi = self.encode_prompt(prompt, positive=True)

View File

@@ -5,3 +5,4 @@ from .sd3_prompter import SD3Prompter
from .hunyuan_dit_prompter import HunyuanDiTPrompter from .hunyuan_dit_prompter import HunyuanDiTPrompter
from .kolors_prompter import KolorsPrompter from .kolors_prompter import KolorsPrompter
from .flux_prompter import FluxPrompter from .flux_prompter import FluxPrompter
from .omost import OmostPromter

View File

@@ -42,7 +42,7 @@ class BasePrompter:
self.extenders = extenders self.extenders = extenders
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]): # manager def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
for refiner_class in refiner_classes: for refiner_class in refiner_classes:
refiner = refiner_class.from_model_manager(model_manager) refiner = refiner_class.from_model_manager(model_manager)
self.refiners.append(refiner) self.refiners.append(refiner)

View File

@@ -1,6 +1,4 @@
from transformers import AutoTokenizer, TextIteratorStreamer
# from .prompt_refiners import BeautifulPrompt
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import difflib import difflib
import torch import torch
import numpy as np import numpy as np
@@ -226,10 +224,6 @@ class Canvas:
suffixes=component['suffixes'] suffixes=component['suffixes']
)) ))
import pickle
with open("tmp.pkl","wb+") as f:
pickle.dump(bag_of_conditions,f)
return dict( return dict(
initial_latent=initial_latent, initial_latent=initial_latent,
bag_of_conditions=bag_of_conditions, bag_of_conditions=bag_of_conditions,
@@ -261,10 +255,6 @@ class OmostPromter(torch.nn.Module):
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager): def from_model_manager(model_manager: ModelManager):
# model, model_path = model_manager.fetch_model("omost", require_model_path=True)
# omost = OmostPromter(tokenizer_path=model_path, model=model)
# return omost
print(model_manager)
model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True) model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path)
omost = OmostPromter( omost = OmostPromter(

View File

@@ -1,8 +1,7 @@
import torch import torch
from diffsynth import download_models,FluxImagePipeline from diffsynth import download_models, ModelManager, OmostPromter, FluxImagePipeline
from diffsynth.models.model_manager import ModelManager
from diffsynth.prompters.omost import OmostPromter
download_models(["OmostPrompt"]) download_models(["OmostPrompt"])
download_models(["FLUX.1-dev"]) download_models(["FLUX.1-dev"])
@@ -17,8 +16,9 @@ model_manager.load_models([
pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter]) pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," torch.manual_seed(0)
image = pipe("generate an image of a witch who is releasing ice and fire magic", image = pipe(
num_inference_steps=30, embedded_guidance=3.5, prompt="an image of a witch who is releasing ice and fire magic",
negative_prompt=negative_prompt) num_inference_steps=30, embedded_guidance=3.5
)
image.save("image_omost.jpg") image.save("image_omost.jpg")