diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py index 15f8747..b8f92bb 100644 --- a/diffsynth/models/qwen_image_dit.py +++ b/diffsynth/models/qwen_image_dit.py @@ -158,7 +158,8 @@ class QwenDoubleStreamAttention(nn.Module): self, image: torch.FloatTensor, text: torch.FloatTensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image) txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text) @@ -186,7 +187,7 @@ class QwenDoubleStreamAttention(nn.Module): joint_k = torch.cat([txt_k, img_k], dim=2) joint_v = torch.cat([txt_v, img_v], dim=2) - joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v) + joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v, attn_mask=attention_mask) joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype) @@ -245,6 +246,7 @@ class QwenImageTransformerBlock(nn.Module): text: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each @@ -260,6 +262,7 @@ class QwenImageTransformerBlock(nn.Module): image=img_modulated, text=txt_modulated, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, ) image = image + img_gate * img_attn_out @@ -309,6 +312,69 @@ class QwenImageDiT(torch.nn.Module): self.proj_out = nn.Linear(3072, 64) + def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes): + 
# prompt_emb + all_prompt_emb = entity_prompt_emb + [prompt_emb] + all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb] + all_prompt_emb = torch.cat(all_prompt_emb, dim=1) + + # image_rotary_emb + txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask] + entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens] + txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) + + # attention_mask + repeat_dim = latents.shape[1] + max_masks = entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype) + entity_masks = entity_masks + [global_mask] + + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()] + total_seq_len = sum(seq_lens) + image.shape[1] + patched_masks = [] + for i in range(N): + patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + patched_masks.append(patched_mask) + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + # prompt-image attention mask + image_start = sum(seq_lens) + image_end = total_seq_len + cumsum = [0] + for length in seq_lens: + cumsum.append(cumsum[-1] + length) + for i in range(N): + prompt_start = cumsum[i] + prompt_end = cumsum[i+1] + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = 
image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt attention mask, let the prompt tokens not attend to each other + for i in range(N): + for j in range(N): + if i == j: + continue + start_i, end_i = cumsum[i], cumsum[i+1] + start_j, end_j = cumsum[j], cumsum[j+1] + attention_mask[:, start_i:end_i, start_j:end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) + + return all_prompt_emb, image_rotary_emb, attention_mask + def forward( self, latents=None, diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index deccd62..0611a7d 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -38,6 +38,7 @@ class QwenImagePipeline(BasePipeline): QwenImageUnit_NoiseInitializer(), QwenImageUnit_InputImageEmbedder(), QwenImageUnit_PromptEmbedder(), + QwenImageUnit_EntityControl(), ] self.model_fn = model_fn_qwen_image @@ -190,6 +191,10 @@ class QwenImagePipeline(BasePipeline): rand_device: str = "cpu", # Steps num_inference_steps: int = 30, + # EliGen + eligen_entity_prompts: list[str] = None, + eligen_entity_masks: list[Image.Image] = None, + eligen_enable_on_negative: bool = False, # Tile tiled: bool = False, tile_size: int = 128, @@ -213,6 +218,7 @@ class QwenImagePipeline(BasePipeline): "height": height, "width": width, "seed": seed, "rand_device": rand_device, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": 
eligen_enable_on_negative, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -322,6 +328,84 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit): return {} +class QwenImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("text_encoder",) + ) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict: + if pipe.text_encoder is not None: + prompt = [prompt] + template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 34 + txt = [template.format(e) for e in prompt] + txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1] + + split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + prompt_embeds = 
prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask} + else: + return {} + + def preprocess_masks(self, pipe, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height): + entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + prompt_embs, prompt_emb_masks = [], [] + for entity_prompt in entity_prompts: + prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt) + prompt_embs.append(prompt_emb_dict['prompt_emb']) + prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask']) + return prompt_embs, prompt_emb_masks, entity_masks + + def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale): + entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi) + entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi) + entity_masks_nega = entity_masks_posi + else: + entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, 
"entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask} + return eligen_kwargs_posi, eligen_kwargs_nega + + def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega): + eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None) + if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) + eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, + eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], + eligen_enable_on_negative, inputs_shared["cfg_scale"]) + inputs_posi.update(eligen_kwargs_posi) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update(eligen_kwargs_nega) + return inputs_shared, inputs_posi, inputs_nega + def model_fn_qwen_image( dit: QwenImageDiT = None, @@ -331,6 +415,9 @@ def model_fn_qwen_image( prompt_emb_mask=None, height=None, width=None, + entity_prompt_emb=None, + entity_prompt_emb_mask=None, + entity_masks=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs @@ -342,9 +429,17 @@ def model_fn_qwen_image( image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) image = dit.img_in(image) - text = dit.txt_in(dit.txt_norm(prompt_emb)) conditioning = dit.time_text_embed(timestep, image.dtype) - image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + + if entity_prompt_emb is not None: + text, image_rotary_emb, attention_mask = dit.process_entity_masks( + latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, + entity_masks, height, width, 
image, img_shapes, + ) + else: + text = dit.txt_in(dit.txt_norm(prompt_emb)) + image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device) + attention_mask = None for block in dit.transformer_blocks: text, image = gradient_checkpoint_forward( @@ -355,6 +450,7 @@ def model_fn_qwen_image( text=text, temb=conditioning, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, ) image = dit.norm_out(image, conditioning) diff --git a/examples/qwen_image/model_inference/Qwen-Image-EliGen.py b/examples/qwen_image/model_inference/Qwen-Image-EliGen.py new file mode 100644 index 0000000..ef06eef --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-EliGen.py @@ -0,0 +1,89 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from PIL import Image, ImageDraw, ImageFont +from modelscope import dataset_snapshot_download +import random + + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = 
mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +example_id = 1 +global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a long dress, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea\n" +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png") +entity_prompts = ["cliff", "sea", "red moon", "sailing boat", "a seated beautiful woman wearing red dress", "yellow long dress"] +masks = 
[Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + +for seed in range(20): + image = pipe(global_prompt, seed=seed, num_inference_steps=40, eligen_entity_prompts=entity_prompts, eligen_entity_masks=masks, cfg_scale=4.0, height=1024, width=1024) + image.save(f"workdirs/qwen_image/eligen_{seed}.jpg") + + visualize_masks(image, masks, entity_prompts, f"workdirs/qwen_image/eligen_{seed}_mask.png") + + image1 = pipe(global_prompt, seed=seed, num_inference_steps=40, height=1024, width=1024, cfg_scale=4.0) + image1.save(f"workdirs/qwen_image/qwenimage_{seed}.jpg") diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh b/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh new file mode 100644 index 0000000..ea2e659 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh @@ -0,0 +1,19 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_eligen.json \ + --data_file_keys "image,eligen_entity_masks" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-EliGen_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --extra_inputs "eligen_entity_masks,eligen_entity_prompts" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py new file mode 100644 index 0000000..90680d3 --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py @@ -0,0 +1,29 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch +from PIL import Image + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-EliGen_lora/epoch-4.safetensors") + + +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + +image = pipe(global_prompt, + seed=0, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks) +image.save("image.jpg")