mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -55,11 +55,14 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
    """Merge the prompter's extension of ``prompt`` into the local-prompt inputs.

    Calls ``self.prompter.extend_prompt(prompt)`` and reads the optional keys
    ``"prompt"``, ``"prompts"`` and ``"masks"`` from the mapping it returns.

    Args:
        prompt: Global prompt. Replaced by the mapping's ``"prompt"`` entry
            when that key is present.
        local_prompts: Caller-supplied local prompts, or ``None``/empty.
        masks: Caller-supplied masks, or ``None``/empty.
        mask_scales: Caller-supplied mask scales, or ``None``/empty.

    Returns:
        The tuple ``(prompt, local_prompts, masks, mask_scales)``. The three
        lists are newly built; the caller's input lists are left untouched.
    """
    extended_prompt_dict = self.prompter.extend_prompt(prompt)
    added_masks = extended_prompt_dict.get("masks", [])

    prompt = extended_prompt_dict.get("prompt", prompt)
    # Build new lists instead of using ``+=`` so that a list object handed in
    # by the caller is never extended in place.
    local_prompts = list(local_prompts or []) + list(extended_prompt_dict.get("prompts", []))
    masks = list(masks or []) + list(added_masks)
    # Exactly one scale per added mask, so ``masks`` and ``mask_scales`` stay
    # the same length (the stale ``5.0`` line appended a second scale per mask).
    mask_scales = list(mask_scales or []) + [100.0] * len(added_masks)
    return prompt, local_prompts, masks, mask_scales
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
|
||||
@@ -75,9 +75,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
local_prompts=[],
|
||||
masks=[],
|
||||
mask_scales=[],
|
||||
local_prompts= None,
|
||||
masks= None,
|
||||
mask_scales= None,
|
||||
negative_prompt="",
|
||||
cfg_scale=1.0,
|
||||
embedded_guidance=0.0,
|
||||
|
||||
@@ -37,9 +37,9 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
|
||||
|
||||
|
||||
class BasePrompter:
    def __init__(self, refiners=None, extenders=None):
        """Initialise the prompter's ``refiners`` and ``extenders`` lists.

        Args:
            refiners: Optional iterable of initial entries for
                ``self.refiners``. Omitted or ``None`` gives an empty list.
            extenders: Optional iterable of initial entries for
                ``self.extenders``. Omitted or ``None`` gives an empty list.
        """
        # ``None`` defaults plus a copy: a ``[]`` default would be a single
        # list object shared by every instance created without arguments.
        self.refiners = list(refiners) if refiners is not None else []
        self.extenders = list(extenders) if extenders is not None else []
|
||||
|
||||
|
||||
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
|
||||
|
||||
@@ -129,7 +129,7 @@ class Canvas:
|
||||
self.suffixes = []
|
||||
return
|
||||
|
||||
def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str,
|
||||
def set_global_description(self, description: str, detailed_descriptions: list, tags: str,
|
||||
HTML_web_color_name: str):
|
||||
assert isinstance(description, str), 'Global description is not valid!'
|
||||
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
|
||||
@@ -151,7 +151,7 @@ class Canvas:
|
||||
return
|
||||
|
||||
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
|
||||
detailed_descriptions: list[str], tags: str, atmosphere: str, style: str,
|
||||
detailed_descriptions: list, tags: str, atmosphere: str, style: str,
|
||||
quality_meta: str, HTML_web_color_name: str):
|
||||
assert isinstance(description, str), 'Local description is wrong!'
|
||||
assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
|
||||
@@ -189,7 +189,8 @@ class Canvas:
|
||||
distance_to_viewer=distance_to_viewer,
|
||||
color=color,
|
||||
prefixes=prefixes,
|
||||
suffixes=suffixes
|
||||
suffixes=suffixes,
|
||||
location=location,
|
||||
))
|
||||
|
||||
return
|
||||
@@ -211,7 +212,7 @@ class Canvas:
|
||||
# compute conditions
|
||||
|
||||
bag_of_conditions = [
|
||||
dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes)
|
||||
dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes,location= "full")
|
||||
]
|
||||
|
||||
for i, component in enumerate(self.components):
|
||||
@@ -219,14 +220,15 @@ class Canvas:
|
||||
m = np.zeros(shape=(90, 90), dtype=np.float32)
|
||||
m[a:b, c:d] = 1.0
|
||||
bag_of_conditions.append(dict(
|
||||
mask=m,
|
||||
prefixes=component['prefixes'],
|
||||
suffixes=component['suffixes']
|
||||
mask = m,
|
||||
prefixes = component['prefixes'],
|
||||
suffixes = component['suffixes'],
|
||||
location = component['location'],
|
||||
))
|
||||
|
||||
return dict(
|
||||
initial_latent=initial_latent,
|
||||
bag_of_conditions=bag_of_conditions,
|
||||
initial_latent = initial_latent,
|
||||
bag_of_conditions = bag_of_conditions,
|
||||
)
|
||||
|
||||
|
||||
@@ -258,8 +260,9 @@ class OmostPromter(torch.nn.Module):
|
||||
model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
omost = OmostPromter(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model= model,
|
||||
tokenizer = tokenizer,
|
||||
device = model_manager.device
|
||||
)
|
||||
return omost
|
||||
|
||||
@@ -271,13 +274,16 @@ class OmostPromter(torch.nn.Module):
|
||||
|
||||
input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.bfloat16, device=self.device)
|
||||
|
||||
generate_kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
streamer=streamer,
|
||||
input_ids = input_ids,
|
||||
streamer = streamer,
|
||||
# stopping_criteria=stopping_criteria,
|
||||
# max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
do_sample = True,
|
||||
attention_mask = attention_mask,
|
||||
pad_token_id = self.tokenizer.eos_token_id,
|
||||
# temperature=temperature,
|
||||
# top_p=top_p,
|
||||
)
|
||||
@@ -290,7 +296,7 @@ class OmostPromter(torch.nn.Module):
|
||||
canvas = Canvas.from_bot_response(llm_outputs)
|
||||
canvas_output = canvas.process()
|
||||
|
||||
prompts = [" ".join(_["prefixes"]+_["suffixes"]) for _ in canvas_output["bag_of_conditions"]]
|
||||
prompts = [" ".join(_["prefixes"]+_["suffixes"][:2]) for _ in canvas_output["bag_of_conditions"]]
|
||||
canvas_output["prompt"] = prompts[0]
|
||||
canvas_output["prompts"] = prompts[1:]
|
||||
|
||||
@@ -302,8 +308,14 @@ class OmostPromter(torch.nn.Module):
|
||||
masks.append(Image.fromarray(mask))
|
||||
|
||||
canvas_output["masks"] = masks
|
||||
|
||||
prompt_dict.update(canvas_output)
|
||||
print(f"Your prompt is extended by Omost:\n")
|
||||
cnt = 0
|
||||
for component,pmt in zip(canvas_output["bag_of_conditions"],prompts):
|
||||
loc = component["location"]
|
||||
cnt += 1
|
||||
print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
|
||||
|
||||
return prompt_dict
|
||||
|
||||
|
||||
|
||||
@@ -14,11 +14,22 @@ model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||
])
|
||||
|
||||
# Two pipelines built from the same model manager: one with the Omost prompt
# extender attached and one without, so their outputs can be compared.
pipe_omost = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
pipe = FluxImagePipeline.from_model_manager(model_manager)

prompt = "A witch uses ice magic to fight against wild beasts"
seed = 7

# Re-seed before each run so both pipelines start from the same RNG state.
torch.manual_seed(seed)
image = pipe_omost(
    prompt=prompt,
    num_inference_steps=30, embedded_guidance=3.5
)
image.save("image_omost.jpg")

torch.manual_seed(seed)
image2 = pipe(
    prompt=prompt,
    num_inference_steps=30, embedded_guidance=3.5
)
image2.save("image.jpg")
|
||||
Reference in New Issue
Block a user