update omost (#190)

* update omost
This commit is contained in:
ZhouTianchen
2024-09-09 17:39:46 +08:00
committed by GitHub
parent 1887885274
commit 995f3374f1
5 changed files with 55 additions and 29 deletions

View File

@@ -55,11 +55,14 @@ class BasePipeline(torch.nn.Module):
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
local_prompts = local_prompts or []
masks = masks or []
mask_scales = mask_scales or []
extended_prompt_dict = self.prompter.extend_prompt(prompt)
prompt = extended_prompt_dict.get("prompt", prompt)
local_prompts += extended_prompt_dict.get("prompts", [])
masks += extended_prompt_dict.get("masks", [])
mask_scales += [5.0] * len(extended_prompt_dict.get("masks", []))
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self):

View File

@@ -75,9 +75,9 @@ class FluxImagePipeline(BasePipeline):
def __call__(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
local_prompts= None,
masks= None,
mask_scales= None,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=0.0,

View File

@@ -37,9 +37,9 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
class BasePrompter:
def __init__(self, refiners=[], extenders=[]):
self.refiners = refiners
self.extenders = extenders
def __init__(self):
self.refiners = []
self.extenders = []
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):

View File

@@ -129,7 +129,7 @@ class Canvas:
self.suffixes = []
return
def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str,
def set_global_description(self, description: str, detailed_descriptions: list, tags: str,
HTML_web_color_name: str):
assert isinstance(description, str), 'Global description is not valid!'
assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
@@ -151,7 +151,7 @@ class Canvas:
return
def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
detailed_descriptions: list[str], tags: str, atmosphere: str, style: str,
detailed_descriptions: list, tags: str, atmosphere: str, style: str,
quality_meta: str, HTML_web_color_name: str):
assert isinstance(description, str), 'Local description is wrong!'
assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
@@ -189,7 +189,8 @@ class Canvas:
distance_to_viewer=distance_to_viewer,
color=color,
prefixes=prefixes,
suffixes=suffixes
suffixes=suffixes,
location=location,
))
return
@@ -211,7 +212,7 @@ class Canvas:
# compute conditions
bag_of_conditions = [
dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes)
dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes,location= "full")
]
for i, component in enumerate(self.components):
@@ -219,14 +220,15 @@ class Canvas:
m = np.zeros(shape=(90, 90), dtype=np.float32)
m[a:b, c:d] = 1.0
bag_of_conditions.append(dict(
mask=m,
prefixes=component['prefixes'],
suffixes=component['suffixes']
mask = m,
prefixes = component['prefixes'],
suffixes = component['suffixes'],
location = component['location'],
))
return dict(
initial_latent=initial_latent,
bag_of_conditions=bag_of_conditions,
initial_latent = initial_latent,
bag_of_conditions = bag_of_conditions,
)
@@ -258,8 +260,9 @@ class OmostPromter(torch.nn.Module):
model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
omost = OmostPromter(
model=model,
tokenizer=tokenizer,
model= model,
tokenizer = tokenizer,
device = model_manager.device
)
return omost
@@ -271,13 +274,16 @@ class OmostPromter(torch.nn.Module):
input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
attention_mask = torch.ones(input_ids.shape, dtype=torch.bfloat16, device=self.device)
generate_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
input_ids = input_ids,
streamer = streamer,
# stopping_criteria=stopping_criteria,
# max_new_tokens=max_new_tokens,
do_sample=True,
do_sample = True,
attention_mask = attention_mask,
pad_token_id = self.tokenizer.eos_token_id,
# temperature=temperature,
# top_p=top_p,
)
@@ -290,7 +296,7 @@ class OmostPromter(torch.nn.Module):
canvas = Canvas.from_bot_response(llm_outputs)
canvas_output = canvas.process()
prompts = [" ".join(_["prefixes"]+_["suffixes"]) for _ in canvas_output["bag_of_conditions"]]
prompts = [" ".join(_["prefixes"]+_["suffixes"][:2]) for _ in canvas_output["bag_of_conditions"]]
canvas_output["prompt"] = prompts[0]
canvas_output["prompts"] = prompts[1:]
@@ -302,8 +308,14 @@ class OmostPromter(torch.nn.Module):
masks.append(Image.fromarray(mask))
canvas_output["masks"] = masks
prompt_dict.update(canvas_output)
print(f"Your prompt is extended by Omost:\n")
cnt = 0
for component,pmt in zip(canvas_output["bag_of_conditions"],prompts):
loc = component["location"]
cnt += 1
print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
return prompt_dict

View File

@@ -14,11 +14,22 @@ model_manager.load_models([
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
pipe_omost = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter])
pipe = FluxImagePipeline.from_model_manager(model_manager)
torch.manual_seed(0)
image = pipe(
prompt="an image of a witch who is releasing ice and fire magic",
prompt = "A witch uses ice magic to fight against wild beasts"
seed = 7
torch.manual_seed(seed)
image = pipe_omost(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5
)
image.save("image_omost.jpg")
image.save(f"image_omost.jpg")
torch.manual_seed(seed)
image2= pipe(
prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5
)
image2.save(f"image.jpg")