mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 02:38:10 +00:00
@@ -55,11 +55,14 @@ class BasePipeline(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
    """Extend a prompt with the regional prompts produced by the prompter.

    Calls ``self.prompter.extend_prompt(prompt)`` and merges the returned
    dictionary into the user-supplied regional conditioning.

    Args:
        prompt: Global prompt; replaced by the dictionary's ``"prompt"``
            entry when one is present.
        local_prompts: Optional sequence of regional prompts (``None`` or
            empty means none).
        masks: Optional sequence of masks paired with ``local_prompts``.
        mask_scales: Optional sequence of scales paired with ``masks``.

    Returns:
        A tuple ``(prompt, local_prompts, masks, mask_scales)`` where the
        three lists are new objects holding the caller's entries followed
        by the extender's ``"prompts"`` / ``"masks"`` entries, each added
        mask receiving a scale of ``100.0``.
    """
    # Copy the inputs: the in-place `+=` below must never grow lists owned by
    # the caller, otherwise repeated calls accumulate extender output.
    local_prompts = list(local_prompts or [])
    masks = list(masks or [])
    mask_scales = list(mask_scales or [])

    extended_prompt_dict = self.prompter.extend_prompt(prompt)
    prompt = extended_prompt_dict.get("prompt", prompt)

    # Look the masks up once so the scale count always matches what is appended.
    extra_masks = extended_prompt_dict.get("masks", [])
    local_prompts += extended_prompt_dict.get("prompts", [])
    masks += extra_masks
    mask_scales += [100.0] * len(extra_masks)
    return prompt, local_prompts, masks, mask_scales
|
||||||
|
|
||||||
def enable_cpu_offload(self):
|
def enable_cpu_offload(self):
|
||||||
|
|||||||
@@ -75,9 +75,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
local_prompts=[],
|
local_prompts= None,
|
||||||
masks=[],
|
masks= None,
|
||||||
mask_scales=[],
|
mask_scales= None,
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
cfg_scale=1.0,
|
cfg_scale=1.0,
|
||||||
embedded_guidance=0.0,
|
embedded_guidance=0.0,
|
||||||
|
|||||||
@@ -37,9 +37,9 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
|
|||||||
|
|
||||||
|
|
||||||
class BasePrompter:
|
class BasePrompter:
|
||||||
def __init__(self, refiners=[], extenders=[]):
|
def __init__(self):
|
||||||
self.refiners = refiners
|
self.refiners = []
|
||||||
self.extenders = extenders
|
self.extenders = []
|
||||||
|
|
||||||
|
|
||||||
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
|
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ class Canvas:
|
|||||||
self.suffixes = []
|
self.suffixes = []
|
||||||
return
|
return
|
||||||
|
|
||||||
def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str,
|
def set_global_description(self, description: str, detailed_descriptions: list, tags: str,
|
||||||
HTML_web_color_name: str):
|
HTML_web_color_name: str):
|
||||||
assert isinstance(description, str), 'Global description is not valid!'
|
assert isinstance(description, str), 'Global description is not valid!'
|
||||||
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
|
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
|
||||||
@@ -151,7 +151,7 @@ class Canvas:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
|
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
|
||||||
detailed_descriptions: list[str], tags: str, atmosphere: str, style: str,
|
detailed_descriptions: list, tags: str, atmosphere: str, style: str,
|
||||||
quality_meta: str, HTML_web_color_name: str):
|
quality_meta: str, HTML_web_color_name: str):
|
||||||
assert isinstance(description, str), 'Local description is wrong!'
|
assert isinstance(description, str), 'Local description is wrong!'
|
||||||
assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
|
assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
|
||||||
@@ -189,7 +189,8 @@ class Canvas:
|
|||||||
distance_to_viewer=distance_to_viewer,
|
distance_to_viewer=distance_to_viewer,
|
||||||
color=color,
|
color=color,
|
||||||
prefixes=prefixes,
|
prefixes=prefixes,
|
||||||
suffixes=suffixes
|
suffixes=suffixes,
|
||||||
|
location=location,
|
||||||
))
|
))
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -211,7 +212,7 @@ class Canvas:
|
|||||||
# compute conditions
|
# compute conditions
|
||||||
|
|
||||||
bag_of_conditions = [
|
bag_of_conditions = [
|
||||||
dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes)
|
dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes,location= "full")
|
||||||
]
|
]
|
||||||
|
|
||||||
for i, component in enumerate(self.components):
|
for i, component in enumerate(self.components):
|
||||||
@@ -219,14 +220,15 @@ class Canvas:
|
|||||||
m = np.zeros(shape=(90, 90), dtype=np.float32)
|
m = np.zeros(shape=(90, 90), dtype=np.float32)
|
||||||
m[a:b, c:d] = 1.0
|
m[a:b, c:d] = 1.0
|
||||||
bag_of_conditions.append(dict(
|
bag_of_conditions.append(dict(
|
||||||
mask=m,
|
mask = m,
|
||||||
prefixes=component['prefixes'],
|
prefixes = component['prefixes'],
|
||||||
suffixes=component['suffixes']
|
suffixes = component['suffixes'],
|
||||||
|
location = component['location'],
|
||||||
))
|
))
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
initial_latent=initial_latent,
|
initial_latent = initial_latent,
|
||||||
bag_of_conditions=bag_of_conditions,
|
bag_of_conditions = bag_of_conditions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -258,8 +260,9 @@ class OmostPromter(torch.nn.Module):
|
|||||||
model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
|
model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
omost = OmostPromter(
|
omost = OmostPromter(
|
||||||
model=model,
|
model= model,
|
||||||
tokenizer=tokenizer,
|
tokenizer = tokenizer,
|
||||||
|
device = model_manager.device
|
||||||
)
|
)
|
||||||
return omost
|
return omost
|
||||||
|
|
||||||
@@ -271,13 +274,16 @@ class OmostPromter(torch.nn.Module):
|
|||||||
|
|
||||||
input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
|
input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
|
||||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.bfloat16, device=self.device)
|
||||||
|
|
||||||
generate_kwargs = dict(
|
generate_kwargs = dict(
|
||||||
input_ids=input_ids,
|
input_ids = input_ids,
|
||||||
streamer=streamer,
|
streamer = streamer,
|
||||||
# stopping_criteria=stopping_criteria,
|
# stopping_criteria=stopping_criteria,
|
||||||
# max_new_tokens=max_new_tokens,
|
# max_new_tokens=max_new_tokens,
|
||||||
do_sample=True,
|
do_sample = True,
|
||||||
|
attention_mask = attention_mask,
|
||||||
|
pad_token_id = self.tokenizer.eos_token_id,
|
||||||
# temperature=temperature,
|
# temperature=temperature,
|
||||||
# top_p=top_p,
|
# top_p=top_p,
|
||||||
)
|
)
|
||||||
@@ -290,7 +296,7 @@ class OmostPromter(torch.nn.Module):
|
|||||||
canvas = Canvas.from_bot_response(llm_outputs)
|
canvas = Canvas.from_bot_response(llm_outputs)
|
||||||
canvas_output = canvas.process()
|
canvas_output = canvas.process()
|
||||||
|
|
||||||
prompts = [" ".join(_["prefixes"]+_["suffixes"]) for _ in canvas_output["bag_of_conditions"]]
|
prompts = [" ".join(_["prefixes"]+_["suffixes"][:2]) for _ in canvas_output["bag_of_conditions"]]
|
||||||
canvas_output["prompt"] = prompts[0]
|
canvas_output["prompt"] = prompts[0]
|
||||||
canvas_output["prompts"] = prompts[1:]
|
canvas_output["prompts"] = prompts[1:]
|
||||||
|
|
||||||
@@ -302,8 +308,14 @@ class OmostPromter(torch.nn.Module):
|
|||||||
masks.append(Image.fromarray(mask))
|
masks.append(Image.fromarray(mask))
|
||||||
|
|
||||||
canvas_output["masks"] = masks
|
canvas_output["masks"] = masks
|
||||||
|
|
||||||
prompt_dict.update(canvas_output)
|
prompt_dict.update(canvas_output)
|
||||||
|
print(f"Your prompt is extended by Omost:\n")
|
||||||
|
cnt = 0
|
||||||
|
for component,pmt in zip(canvas_output["bag_of_conditions"],prompts):
|
||||||
|
loc = component["location"]
|
||||||
|
cnt += 1
|
||||||
|
print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
|
||||||
|
|
||||||
return prompt_dict
|
return prompt_dict
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,22 @@ model_manager.load_models([
|
|||||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||||
])
|
])
|
||||||
|
|
||||||
pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
|
pipe_omost = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
|
||||||
|
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
prompt = "A witch uses ice magic to fight against wild beasts"
|
||||||
image = pipe(
|
seed = 7
|
||||||
prompt="an image of a witch who is releasing ice and fire magic",
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
image = pipe_omost(
|
||||||
|
prompt=prompt,
|
||||||
num_inference_steps=30, embedded_guidance=3.5
|
num_inference_steps=30, embedded_guidance=3.5
|
||||||
)
|
)
|
||||||
image.save("image_omost.jpg")
|
image.save(f"image_omost.jpg")
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
image2= pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
num_inference_steps=30, embedded_guidance=3.5
|
||||||
|
)
|
||||||
|
image2.save(f"image.jpg")
|
||||||
Reference in New Issue
Block a user