update omost (#190)

* update omost
This commit is contained in:
ZhouTianchen
2024-09-09 17:39:46 +08:00
committed by GitHub
parent 1887885274
commit 995f3374f1
5 changed files with 55 additions and 29 deletions

View File

@@ -55,11 +55,14 @@ class BasePipeline(torch.nn.Module):
     def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
+        local_prompts = local_prompts or []
+        masks = masks or []
+        mask_scales = mask_scales or []
         extended_prompt_dict = self.prompter.extend_prompt(prompt)
         prompt = extended_prompt_dict.get("prompt", prompt)
         local_prompts += extended_prompt_dict.get("prompts", [])
         masks += extended_prompt_dict.get("masks", [])
-        mask_scales += [5.0] * len(extended_prompt_dict.get("masks", []))
+        mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
         return prompt, local_prompts, masks, mask_scales
     def enable_cpu_offload(self):

View File

@@ -75,9 +75,9 @@ class FluxImagePipeline(BasePipeline):
     def __call__(
         self,
         prompt,
-        local_prompts=[],
-        masks=[],
-        mask_scales=[],
+        local_prompts= None,
+        masks= None,
+        mask_scales= None,
         negative_prompt="",
         cfg_scale=1.0,
         embedded_guidance=0.0,

View File

@@ -37,9 +37,9 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
 class BasePrompter:
-    def __init__(self, refiners=[], extenders=[]):
-        self.refiners = refiners
-        self.extenders = extenders
+    def __init__(self):
+        self.refiners = []
+        self.extenders = []
     def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):

View File

@@ -129,7 +129,7 @@ class Canvas:
         self.suffixes = []
         return
-    def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str,
+    def set_global_description(self, description: str, detailed_descriptions: list, tags: str,
                                HTML_web_color_name: str):
         assert isinstance(description, str), 'Global description is not valid!'
         assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
@@ -151,7 +151,7 @@ class Canvas:
         return
     def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
-                              detailed_descriptions: list[str], tags: str, atmosphere: str, style: str,
+                              detailed_descriptions: list, tags: str, atmosphere: str, style: str,
                               quality_meta: str, HTML_web_color_name: str):
         assert isinstance(description, str), 'Local description is wrong!'
         assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
@@ -189,7 +189,8 @@ class Canvas:
             distance_to_viewer=distance_to_viewer,
             color=color,
             prefixes=prefixes,
-            suffixes=suffixes
+            suffixes=suffixes,
+            location=location,
         ))
         return
@@ -211,7 +212,7 @@ class Canvas:
         # compute conditions
         bag_of_conditions = [
-            dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes)
+            dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes,location= "full")
         ]
         for i, component in enumerate(self.components):
@@ -221,7 +222,8 @@ class Canvas:
             bag_of_conditions.append(dict(
                 mask = m,
                 prefixes = component['prefixes'],
-                suffixes=component['suffixes']
+                suffixes = component['suffixes'],
+                location = component['location'],
             ))
         return dict(
@@ -260,6 +262,7 @@ class OmostPromter(torch.nn.Module):
         omost = OmostPromter(
             model= model,
             tokenizer = tokenizer,
+            device = model_manager.device
         )
         return omost
@@ -271,6 +274,7 @@ class OmostPromter(torch.nn.Module):
         input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.bfloat16, device=self.device)
         generate_kwargs = dict(
             input_ids = input_ids,
@@ -278,6 +282,8 @@ class OmostPromter(torch.nn.Module):
             # stopping_criteria=stopping_criteria,
             # max_new_tokens=max_new_tokens,
             do_sample = True,
+            attention_mask = attention_mask,
+            pad_token_id = self.tokenizer.eos_token_id,
             # temperature=temperature,
             # top_p=top_p,
         )
@@ -290,7 +296,7 @@ class OmostPromter(torch.nn.Module):
         canvas = Canvas.from_bot_response(llm_outputs)
         canvas_output = canvas.process()
-        prompts = [" ".join(_["prefixes"]+_["suffixes"]) for _ in canvas_output["bag_of_conditions"]]
+        prompts = [" ".join(_["prefixes"]+_["suffixes"][:2]) for _ in canvas_output["bag_of_conditions"]]
         canvas_output["prompt"] = prompts[0]
         canvas_output["prompts"] = prompts[1:]
@@ -302,8 +308,14 @@ class OmostPromter(torch.nn.Module):
             masks.append(Image.fromarray(mask))
         canvas_output["masks"] = masks
         prompt_dict.update(canvas_output)
+        print(f"Your prompt is extended by Omost:\n")
+        cnt = 0
+        for component,pmt in zip(canvas_output["bag_of_conditions"],prompts):
+            loc = component["location"]
+            cnt += 1
+            print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
         return prompt_dict

View File

@@ -14,11 +14,22 @@ model_manager.load_models([
     "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
 ])
-pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
-torch.manual_seed(0)
-image = pipe(
-    prompt="an image of a witch who is releasing ice and fire magic",
+pipe_omost = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
+pipe = FluxImagePipeline.from_model_manager(model_manager)
+
+prompt = "A witch uses ice magic to fight against wild beasts"
+seed = 7
+
+torch.manual_seed(seed)
+image = pipe_omost(
+    prompt=prompt,
     num_inference_steps=30, embedded_guidance=3.5
 )
-image.save("image_omost.jpg")
+image.save(f"image_omost.jpg")
+
+torch.manual_seed(seed)
+image2= pipe(
+    prompt=prompt,
+    num_inference_steps=30, embedded_guidance=3.5
+)
+image2.save(f"image.jpg")