mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
@@ -86,6 +86,7 @@ huggingface_model_loader_configs = [
|
|||||||
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
||||||
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
||||||
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
||||||
|
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
||||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||||
]
|
]
|
||||||
@@ -227,6 +228,18 @@ preset_models_on_modelscope = {
|
|||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
],
|
],
|
||||||
|
# Omost prompt
|
||||||
|
"OmostPrompt":[
|
||||||
|
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
|
],
|
||||||
|
|
||||||
# Translator
|
# Translator
|
||||||
"opus-mt-zh-en": [
|
"opus-mt-zh-en": [
|
||||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||||
@@ -325,6 +338,7 @@ Preset_model_id: TypeAlias = Literal[
|
|||||||
"ControlNet_union_sdxl_promax",
|
"ControlNet_union_sdxl_promax",
|
||||||
"FLUX.1-dev",
|
"FLUX.1-dev",
|
||||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||||
|
"OmostPrompt",
|
||||||
"ESRGAN_x4",
|
"ESRGAN_x4",
|
||||||
"RIFE",
|
"RIFE",
|
||||||
"CogVideoX-5B",
|
"CogVideoX-5B",
|
||||||
|
|||||||
@@ -119,7 +119,10 @@ def load_model_from_huggingface_folder(file_path, model_names, model_classes, to
|
|||||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
||||||
model = model.half()
|
model = model.half()
|
||||||
model = model.to(device=device)
|
try:
|
||||||
|
model = model.to(device=device)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
loaded_model_names.append(model_name)
|
loaded_model_names.append(model_name)
|
||||||
loaded_models.append(model)
|
loaded_models.append(model)
|
||||||
return loaded_model_names, loaded_models
|
return loaded_model_names, loaded_models
|
||||||
|
|||||||
@@ -51,3 +51,12 @@ class BasePipeline(torch.nn.Module):
|
|||||||
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
|
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
|
|
||||||
|
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
|
||||||
|
extended_prompt_dict = self.prompter.extend_prompt(prompt)
|
||||||
|
prompt = extended_prompt_dict.get("prompt", prompt)
|
||||||
|
local_prompts += extended_prompt_dict.get("prompts", [])
|
||||||
|
masks += extended_prompt_dict.get("masks", [])
|
||||||
|
mask_scales += [5.0] * len(extended_prompt_dict.get("masks", []))
|
||||||
|
return prompt, local_prompts, masks, mask_scales
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return self.dit
|
return self.dit
|
||||||
|
|
||||||
|
|
||||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]):
|
||||||
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
|
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
|
||||||
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
|
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
|
||||||
self.dit = model_manager.fetch_model("flux_dit")
|
self.dit = model_manager.fetch_model("flux_dit")
|
||||||
@@ -33,15 +33,16 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
|
self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
|
||||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||||
|
self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
|
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[],prompt_extender_classes=[]):
|
||||||
pipe = FluxImagePipeline(
|
pipe = FluxImagePipeline(
|
||||||
device=model_manager.device,
|
device=model_manager.device,
|
||||||
torch_dtype=model_manager.torch_dtype,
|
torch_dtype=model_manager.torch_dtype,
|
||||||
)
|
)
|
||||||
pipe.fetch_models(model_manager, prompt_refiner_classes)
|
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
@@ -105,6 +106,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
else:
|
else:
|
||||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||||
|
|
||||||
|
# Extend prompt
|
||||||
|
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
|
||||||
|
|
||||||
# Encode prompts
|
# Encode prompts
|
||||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
|
|||||||
@@ -5,4 +5,5 @@ from .sd3_prompter import SD3Prompter
|
|||||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||||
from .kolors_prompter import KolorsPrompter
|
from .kolors_prompter import KolorsPrompter
|
||||||
from .flux_prompter import FluxPrompter
|
from .flux_prompter import FluxPrompter
|
||||||
|
from .omost import OmostPromter
|
||||||
from .cog_prompter import CogPrompter
|
from .cog_prompter import CogPrompter
|
||||||
|
|||||||
@@ -37,15 +37,21 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
|
|||||||
|
|
||||||
|
|
||||||
class BasePrompter:
|
class BasePrompter:
|
||||||
def __init__(self, refiners=[]):
|
def __init__(self, refiners=[], extenders=[]):
|
||||||
self.refiners = refiners
|
self.refiners = refiners
|
||||||
|
self.extenders = extenders
|
||||||
|
|
||||||
|
|
||||||
def load_prompt_refiners(self, model_nameger: ModelManager, refiner_classes=[]):
|
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
|
||||||
for refiner_class in refiner_classes:
|
for refiner_class in refiner_classes:
|
||||||
refiner = refiner_class.from_model_manager(model_nameger)
|
refiner = refiner_class.from_model_manager(model_manager)
|
||||||
self.refiners.append(refiner)
|
self.refiners.append(refiner)
|
||||||
|
|
||||||
|
def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]):
|
||||||
|
for extender_class in extender_classes:
|
||||||
|
extender = extender_class.from_model_manager(model_manager)
|
||||||
|
self.extenders.append(extender)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def process_prompt(self, prompt, positive=True):
|
def process_prompt(self, prompt, positive=True):
|
||||||
@@ -55,3 +61,10 @@ class BasePrompter:
|
|||||||
for refiner in self.refiners:
|
for refiner in self.refiners:
|
||||||
prompt = refiner(prompt, positive=positive)
|
prompt = refiner(prompt, positive=positive)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def extend_prompt(self, prompt:str, positive=True):
|
||||||
|
extended_prompt = dict(prompt=prompt)
|
||||||
|
for extender in self.extenders:
|
||||||
|
extended_prompt = extender(extended_prompt)
|
||||||
|
return extended_prompt
|
||||||
311
diffsynth/prompters/omost.py
Normal file
311
diffsynth/prompters/omost.py
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
from transformers import AutoTokenizer, TextIteratorStreamer
|
||||||
|
import difflib
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import re
|
||||||
|
from ..models.model_manager import ModelManager
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
valid_colors = { # r, g, b
|
||||||
|
'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
|
||||||
|
'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
|
||||||
|
'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
|
||||||
|
'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
|
||||||
|
'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
|
||||||
|
'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
|
||||||
|
'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
|
||||||
|
'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
|
||||||
|
'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
|
||||||
|
'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
|
||||||
|
'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
|
||||||
|
'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
|
||||||
|
'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
|
||||||
|
'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
|
||||||
|
'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
|
||||||
|
'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
|
||||||
|
'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
|
||||||
|
'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
|
||||||
|
'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
|
||||||
|
'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
|
||||||
|
'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
|
||||||
|
'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
|
||||||
|
'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
|
||||||
|
'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
|
||||||
|
'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
|
||||||
|
'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
|
||||||
|
'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
|
||||||
|
'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
|
||||||
|
'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
|
||||||
|
'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
|
||||||
|
'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
|
||||||
|
'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
|
||||||
|
'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
|
||||||
|
'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
|
||||||
|
'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
|
||||||
|
'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
|
||||||
|
'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
|
||||||
|
'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
|
||||||
|
'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
|
||||||
|
'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
|
||||||
|
'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
|
||||||
|
'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
|
||||||
|
'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
|
||||||
|
'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
|
||||||
|
'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
|
||||||
|
'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
|
||||||
|
'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
|
||||||
|
'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
|
||||||
|
}
|
||||||
|
|
||||||
|
valid_locations = { # x, y in 90*90
|
||||||
|
'in the center': (45, 45),
|
||||||
|
'on the left': (15, 45),
|
||||||
|
'on the right': (75, 45),
|
||||||
|
'on the top': (45, 15),
|
||||||
|
'on the bottom': (45, 75),
|
||||||
|
'on the top-left': (15, 15),
|
||||||
|
'on the top-right': (75, 15),
|
||||||
|
'on the bottom-left': (15, 75),
|
||||||
|
'on the bottom-right': (75, 75)
|
||||||
|
}
|
||||||
|
|
||||||
|
valid_offsets = { # x, y in 90*90
|
||||||
|
'no offset': (0, 0),
|
||||||
|
'slightly to the left': (-10, 0),
|
||||||
|
'slightly to the right': (10, 0),
|
||||||
|
'slightly to the upper': (0, -10),
|
||||||
|
'slightly to the lower': (0, 10),
|
||||||
|
'slightly to the upper-left': (-10, -10),
|
||||||
|
'slightly to the upper-right': (10, -10),
|
||||||
|
'slightly to the lower-left': (-10, 10),
|
||||||
|
'slightly to the lower-right': (10, 10)}
|
||||||
|
|
||||||
|
valid_areas = { # w, h in 90*90
|
||||||
|
"a small square area": (50, 50),
|
||||||
|
"a small vertical area": (40, 60),
|
||||||
|
"a small horizontal area": (60, 40),
|
||||||
|
"a medium-sized square area": (60, 60),
|
||||||
|
"a medium-sized vertical area": (50, 80),
|
||||||
|
"a medium-sized horizontal area": (80, 50),
|
||||||
|
"a large square area": (70, 70),
|
||||||
|
"a large vertical area": (60, 90),
|
||||||
|
"a large horizontal area": (90, 60)
|
||||||
|
}
|
||||||
|
|
||||||
|
def safe_str(x):
|
||||||
|
return x.strip(',. ') + '.'
|
||||||
|
|
||||||
|
def closest_name(input_str, options):
|
||||||
|
input_str = input_str.lower()
|
||||||
|
|
||||||
|
closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5)
|
||||||
|
assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!'
|
||||||
|
result = closest_match[0]
|
||||||
|
|
||||||
|
if result != input_str:
|
||||||
|
print(f'Automatically corrected [{input_str}] -> [{result}].')
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
class Canvas:
|
||||||
|
@staticmethod
|
||||||
|
def from_bot_response(response: str):
|
||||||
|
|
||||||
|
matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
|
||||||
|
assert matched, 'Response does not contain codes!'
|
||||||
|
code_content = matched.group(1)
|
||||||
|
assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!'
|
||||||
|
local_vars = {'Canvas': Canvas}
|
||||||
|
exec(code_content, {}, local_vars)
|
||||||
|
canvas = local_vars.get('canvas', None)
|
||||||
|
assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!'
|
||||||
|
return canvas
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.components = []
|
||||||
|
self.color = None
|
||||||
|
self.record_tags = True
|
||||||
|
self.prefixes = []
|
||||||
|
self.suffixes = []
|
||||||
|
return
|
||||||
|
|
||||||
|
def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str,
|
||||||
|
HTML_web_color_name: str):
|
||||||
|
assert isinstance(description, str), 'Global description is not valid!'
|
||||||
|
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
|
||||||
|
'Global detailed_descriptions is not valid!'
|
||||||
|
assert isinstance(tags, str), 'Global tags is not valid!'
|
||||||
|
|
||||||
|
HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
|
||||||
|
self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
|
||||||
|
|
||||||
|
self.prefixes = [description]
|
||||||
|
self.suffixes = detailed_descriptions
|
||||||
|
|
||||||
|
if self.record_tags:
|
||||||
|
self.suffixes = self.suffixes + [tags]
|
||||||
|
|
||||||
|
self.prefixes = [safe_str(x) for x in self.prefixes]
|
||||||
|
self.suffixes = [safe_str(x) for x in self.suffixes]
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
|
||||||
|
detailed_descriptions: list[str], tags: str, atmosphere: str, style: str,
|
||||||
|
quality_meta: str, HTML_web_color_name: str):
|
||||||
|
assert isinstance(description, str), 'Local description is wrong!'
|
||||||
|
assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
|
||||||
|
f'The distance_to_viewer for [{description}] is not positive float number!'
|
||||||
|
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
|
||||||
|
f'The detailed_descriptions for [{description}] is not valid!'
|
||||||
|
assert isinstance(tags, str), f'The tags for [{description}] is not valid!'
|
||||||
|
assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!'
|
||||||
|
assert isinstance(style, str), f'The style for [{description}] is not valid!'
|
||||||
|
assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!'
|
||||||
|
|
||||||
|
location = closest_name(location, valid_locations)
|
||||||
|
offset = closest_name(offset, valid_offsets)
|
||||||
|
area = closest_name(area, valid_areas)
|
||||||
|
HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
|
||||||
|
|
||||||
|
xb, yb = valid_locations[location]
|
||||||
|
xo, yo = valid_offsets[offset]
|
||||||
|
w, h = valid_areas[area]
|
||||||
|
rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
|
||||||
|
rect = [max(0, min(90, i)) for i in rect]
|
||||||
|
color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
|
||||||
|
|
||||||
|
prefixes = self.prefixes + [description]
|
||||||
|
suffixes = detailed_descriptions
|
||||||
|
|
||||||
|
if self.record_tags:
|
||||||
|
suffixes = suffixes + [tags, atmosphere, style, quality_meta]
|
||||||
|
|
||||||
|
prefixes = [safe_str(x) for x in prefixes]
|
||||||
|
suffixes = [safe_str(x) for x in suffixes]
|
||||||
|
|
||||||
|
self.components.append(dict(
|
||||||
|
rect=rect,
|
||||||
|
distance_to_viewer=distance_to_viewer,
|
||||||
|
color=color,
|
||||||
|
prefixes=prefixes,
|
||||||
|
suffixes=suffixes
|
||||||
|
))
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
# sort components
|
||||||
|
self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True)
|
||||||
|
|
||||||
|
# compute initial latent
|
||||||
|
# print(self.color)
|
||||||
|
initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
|
||||||
|
|
||||||
|
for component in self.components:
|
||||||
|
a, b, c, d = component['rect']
|
||||||
|
initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d]
|
||||||
|
|
||||||
|
initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# compute conditions
|
||||||
|
|
||||||
|
bag_of_conditions = [
|
||||||
|
dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes)
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, component in enumerate(self.components):
|
||||||
|
a, b, c, d = component['rect']
|
||||||
|
m = np.zeros(shape=(90, 90), dtype=np.float32)
|
||||||
|
m[a:b, c:d] = 1.0
|
||||||
|
bag_of_conditions.append(dict(
|
||||||
|
mask=m,
|
||||||
|
prefixes=component['prefixes'],
|
||||||
|
suffixes=component['suffixes']
|
||||||
|
))
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
initial_latent=initial_latent,
|
||||||
|
bag_of_conditions=bag_of_conditions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OmostPromter(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,model = None,tokenizer = None, template = "",device="cpu"):
|
||||||
|
super().__init__()
|
||||||
|
self.model=model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.device = device
|
||||||
|
if template == "":
|
||||||
|
template = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`:
|
||||||
|
```python
|
||||||
|
class Canvas:
|
||||||
|
def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
|
||||||
|
assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
|
||||||
|
assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
|
||||||
|
assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
|
||||||
|
assert distance_to_viewer > 0
|
||||||
|
pass
|
||||||
|
```'''
|
||||||
|
self.template = template
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_model_manager(model_manager: ModelManager):
|
||||||
|
model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
|
omost = OmostPromter(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
return omost
|
||||||
|
|
||||||
|
|
||||||
|
def __call__(self,prompt_dict:dict):
|
||||||
|
raw_prompt=prompt_dict["prompt"]
|
||||||
|
conversation = [{"role": "system", "content": self.template}]
|
||||||
|
conversation.append({"role": "user", "content": raw_prompt})
|
||||||
|
|
||||||
|
input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
|
||||||
|
streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|
||||||
|
generate_kwargs = dict(
|
||||||
|
input_ids=input_ids,
|
||||||
|
streamer=streamer,
|
||||||
|
# stopping_criteria=stopping_criteria,
|
||||||
|
# max_new_tokens=max_new_tokens,
|
||||||
|
do_sample=True,
|
||||||
|
# temperature=temperature,
|
||||||
|
# top_p=top_p,
|
||||||
|
)
|
||||||
|
self.model.generate(**generate_kwargs)
|
||||||
|
outputs = []
|
||||||
|
for text in streamer:
|
||||||
|
outputs.append(text)
|
||||||
|
llm_outputs = "".join(outputs)
|
||||||
|
|
||||||
|
canvas = Canvas.from_bot_response(llm_outputs)
|
||||||
|
canvas_output = canvas.process()
|
||||||
|
|
||||||
|
prompts = [" ".join(_["prefixes"]+_["suffixes"]) for _ in canvas_output["bag_of_conditions"]]
|
||||||
|
canvas_output["prompt"] = prompts[0]
|
||||||
|
canvas_output["prompts"] = prompts[1:]
|
||||||
|
|
||||||
|
raw_masks = [_["mask"] for _ in canvas_output["bag_of_conditions"]]
|
||||||
|
masks=[]
|
||||||
|
for mask in raw_masks:
|
||||||
|
mask[mask>0.5]=255
|
||||||
|
mask = np.stack([mask] * 3, axis=-1).astype("uint8")
|
||||||
|
masks.append(Image.fromarray(mask))
|
||||||
|
|
||||||
|
canvas_output["masks"] = masks
|
||||||
|
|
||||||
|
prompt_dict.update(canvas_output)
|
||||||
|
return prompt_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from ..models.model_manager import ModelManager
|
from ..models.model_manager import ModelManager
|
||||||
import torch
|
import torch
|
||||||
|
from .omost import OmostPromter
|
||||||
|
|
||||||
|
|
||||||
class BeautifulPrompt(torch.nn.Module):
|
class BeautifulPrompt(torch.nn.Module):
|
||||||
def __init__(self, tokenizer_path=None, model=None, template=""):
|
def __init__(self, tokenizer_path=None, model=None, template=""):
|
||||||
@@ -13,8 +12,8 @@ class BeautifulPrompt(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_nameger: ModelManager):
|
def from_model_manager(model_manager: ModelManager):
|
||||||
model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True)
|
model, model_path = model_manager.fetch_model("beautiful_prompt", require_model_path=True)
|
||||||
template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||||
if model_path.endswith("v2"):
|
if model_path.endswith("v2"):
|
||||||
template = """Converts a simple image description into a prompt. \
|
template = """Converts a simple image description into a prompt. \
|
||||||
@@ -63,8 +62,8 @@ class Translator(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_nameger: ModelManager):
|
def from_model_manager(model_manager: ModelManager):
|
||||||
model, model_path = model_nameger.fetch_model("translator", require_model_path=True)
|
model, model_path = model_manager.fetch_model("translator", require_model_path=True)
|
||||||
translator = Translator(tokenizer_path=model_path, model=model)
|
translator = Translator(tokenizer_path=model_path, model=model)
|
||||||
return translator
|
return translator
|
||||||
|
|
||||||
|
|||||||
24
examples/image_synthesis/omost_flux_text_to_image.py
Normal file
24
examples/image_synthesis/omost_flux_text_to_image.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import download_models, ModelManager, OmostPromter, FluxImagePipeline
|
||||||
|
|
||||||
|
|
||||||
|
download_models(["OmostPrompt"])
|
||||||
|
download_models(["FLUX.1-dev"])
|
||||||
|
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.bfloat16)
|
||||||
|
model_manager.load_models([
|
||||||
|
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
||||||
|
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||||
|
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||||
|
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||||
|
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||||
|
])
|
||||||
|
|
||||||
|
pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
image = pipe(
|
||||||
|
prompt="an image of a witch who is releasing ice and fire magic",
|
||||||
|
num_inference_steps=30, embedded_guidance=3.5
|
||||||
|
)
|
||||||
|
image.save("image_omost.jpg")
|
||||||
Reference in New Issue
Block a user