diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index faa1d86..956e9ba 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -55,11 +55,14 @@ class BasePipeline(torch.nn.Module): def extend_prompt(self, prompt, local_prompts, masks, mask_scales): + local_prompts = local_prompts or [] + masks = masks or [] + mask_scales = mask_scales or [] extended_prompt_dict = self.prompter.extend_prompt(prompt) prompt = extended_prompt_dict.get("prompt", prompt) local_prompts += extended_prompt_dict.get("prompts", []) masks += extended_prompt_dict.get("masks", []) - mask_scales += [5.0] * len(extended_prompt_dict.get("masks", [])) + mask_scales += [100.0] * len(extended_prompt_dict.get("masks", [])) return prompt, local_prompts, masks, mask_scales def enable_cpu_offload(self): diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 67d961a..d014a92 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -75,9 +75,9 @@ class FluxImagePipeline(BasePipeline): def __call__( self, prompt, - local_prompts=[], - masks=[], - mask_scales=[], + local_prompts=None, + masks=None, + mask_scales=None, negative_prompt="", cfg_scale=1.0, embedded_guidance=0.0, diff --git a/diffsynth/prompters/base_prompter.py b/diffsynth/prompters/base_prompter.py index 9f0101a..136abd1 100644 --- a/diffsynth/prompters/base_prompter.py +++ b/diffsynth/prompters/base_prompter.py @@ -37,9 +37,9 @@ class BasePrompter: - def __init__(self, refiners=[], extenders=[]): - self.refiners = refiners - self.extenders = extenders + def __init__(self): + self.refiners = [] + self.extenders = [] def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]): diff --git a/diffsynth/prompters/omost.py b/diffsynth/prompters/omost.py index 39999ce..81828ad 100644 --- a/diffsynth/prompters/omost.py +++ b/diffsynth/prompters/omost.py @@ 
-129,7 +129,7 @@ class Canvas: self.suffixes = [] return - def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, + def set_global_description(self, description: str, detailed_descriptions: list, tags: str, HTML_web_color_name: str): assert isinstance(description, str), 'Global description is not valid!' assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \ @@ -151,7 +151,7 @@ class Canvas: return def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, - detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, + detailed_descriptions: list, tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str): assert isinstance(description, str), 'Local description is wrong!' assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \ @@ -189,7 +189,8 @@ class Canvas: distance_to_viewer=distance_to_viewer, color=color, prefixes=prefixes, - suffixes=suffixes + suffixes=suffixes, + location=location, )) return @@ -211,7 +212,7 @@ class Canvas: # compute conditions bag_of_conditions = [ - dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes) + dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes, location="full") ] for i, component in enumerate(self.components): @@ -219,14 +220,15 @@ class Canvas: m = np.zeros(shape=(90, 90), dtype=np.float32) m[a:b, c:d] = 1.0 bag_of_conditions.append(dict( - mask=m, - prefixes=component['prefixes'], - suffixes=component['suffixes'] + mask=m, + prefixes=component['prefixes'], + suffixes=component['suffixes'], + location=component['location'], )) return dict( - initial_latent=initial_latent, - bag_of_conditions=bag_of_conditions, + initial_latent=initial_latent, + bag_of_conditions=bag_of_conditions, ) @@ -258,8 +260,9 
@@ class OmostPromter(torch.nn.Module): model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True) tokenizer = AutoTokenizer.from_pretrained(model_path) omost = OmostPromter( - model=model, - tokenizer=tokenizer, + model=model, + tokenizer=tokenizer, + device=model_manager.device, ) return omost @@ -271,13 +274,16 @@ class OmostPromter(torch.nn.Module): input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device) streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) - + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=self.device) + generate_kwargs = dict( - input_ids=input_ids, - streamer=streamer, + input_ids=input_ids, + streamer=streamer, # stopping_criteria=stopping_criteria, # max_new_tokens=max_new_tokens, - do_sample=True, + do_sample=True, + attention_mask=attention_mask, + pad_token_id=self.tokenizer.eos_token_id, # temperature=temperature, # top_p=top_p, ) @@ -290,7 +296,7 @@ class OmostPromter(torch.nn.Module): canvas = Canvas.from_bot_response(llm_outputs) canvas_output = canvas.process() - prompts = [" ".join(_["prefixes"]+_["suffixes"]) for _ in canvas_output["bag_of_conditions"]] + prompts = [" ".join(_["prefixes"]+_["suffixes"][:2]) for _ in canvas_output["bag_of_conditions"]] canvas_output["prompt"] = prompts[0] canvas_output["prompts"] = prompts[1:] @@ -302,8 +308,15 @@ class OmostPromter(torch.nn.Module): masks.append(Image.fromarray(mask)) canvas_output["masks"] = masks - prompt_dict.update(canvas_output) + prompt_dict.update(canvas_output) + print("Your prompt is extended by Omost:\n") + cnt = 0 + for component, pmt in zip(canvas_output["bag_of_conditions"], prompts): + loc = component["location"] + cnt += 1 + print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n") + return prompt_dict diff --git a/examples/image_synthesis/omost_flux_text_to_image.py 
b/examples/image_synthesis/omost_flux_text_to_image.py index 7562342..e6e5d1d 100644 --- a/examples/image_synthesis/omost_flux_text_to_image.py +++ b/examples/image_synthesis/omost_flux_text_to_image.py @@ -14,11 +14,22 @@ model_manager.load_models([ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" ]) -pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter]) +pipe_omost = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter]) +pipe = FluxImagePipeline.from_model_manager(model_manager) -torch.manual_seed(0) -image = pipe( - prompt="an image of a witch who is releasing ice and fire magic", +prompt = "A witch uses ice magic to fight against wild beasts" +seed = 7 + +torch.manual_seed(seed) +image = pipe_omost( + prompt=prompt, num_inference_steps=30, embedded_guidance=3.5 ) -image.save("image_omost.jpg") +image.save("image_omost.jpg") + +torch.manual_seed(seed) +image2 = pipe( + prompt=prompt, + num_inference_steps=30, embedded_guidance=3.5 +) +image2.save("image.jpg")