mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
qwen_image eligen
This commit is contained in:
@@ -158,7 +158,8 @@ class QwenDoubleStreamAttention(nn.Module):
|
||||
self,
|
||||
image: torch.FloatTensor,
|
||||
text: torch.FloatTensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
|
||||
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
|
||||
@@ -186,7 +187,7 @@ class QwenDoubleStreamAttention(nn.Module):
|
||||
joint_k = torch.cat([txt_k, img_k], dim=2)
|
||||
joint_v = torch.cat([txt_v, img_v], dim=2)
|
||||
|
||||
joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v)
|
||||
joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v, attn_mask=attention_mask)
|
||||
|
||||
joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
|
||||
|
||||
@@ -245,6 +246,7 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
text: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||
@@ -260,6 +262,7 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
image=img_modulated,
|
||||
text=txt_modulated,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
image = image + img_gate * img_attn_out
|
||||
@@ -309,6 +312,69 @@ class QwenImageDiT(torch.nn.Module):
|
||||
self.proj_out = nn.Linear(3072, 64)
|
||||
|
||||
|
||||
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
|
||||
# prompt_emb
|
||||
all_prompt_emb = entity_prompt_emb + [prompt_emb]
|
||||
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
|
||||
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
|
||||
|
||||
# image_rotary_emb
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
|
||||
entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
|
||||
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
|
||||
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
|
||||
|
||||
# attention_mask
|
||||
repeat_dim = latents.shape[1]
|
||||
max_masks = entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
entity_masks = entity_masks + [global_mask]
|
||||
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
|
||||
total_seq_len = sum(seq_lens) + image.shape[1]
|
||||
patched_masks = []
|
||||
for i in range(N):
|
||||
patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
patched_masks.append(patched_mask)
|
||||
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
|
||||
|
||||
# prompt-image attention mask
|
||||
image_start = sum(seq_lens)
|
||||
image_end = total_seq_len
|
||||
cumsum = [0]
|
||||
for length in seq_lens:
|
||||
cumsum.append(cumsum[-1] + length)
|
||||
for i in range(N):
|
||||
prompt_start = cumsum[i]
|
||||
prompt_end = cumsum[i+1]
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
# prompt-prompt attention mask, let the prompt tokens not attend to each other
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i == j:
|
||||
continue
|
||||
start_i, end_i = cumsum[i], cumsum[i+1]
|
||||
start_j, end_j = cumsum[j], cumsum[j+1]
|
||||
attention_mask[:, start_i:end_i, start_j:end_j] = False
|
||||
|
||||
attention_mask = attention_mask.float()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
|
||||
|
||||
return all_prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
|
||||
@@ -38,6 +38,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
QwenImageUnit_NoiseInitializer(),
|
||||
QwenImageUnit_InputImageEmbedder(),
|
||||
QwenImageUnit_PromptEmbedder(),
|
||||
QwenImageUnit_EntityControl(),
|
||||
]
|
||||
self.model_fn = model_fn_qwen_image
|
||||
|
||||
@@ -190,6 +191,10 @@ class QwenImagePipeline(BasePipeline):
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# EliGen
|
||||
eligen_entity_prompts: list[str] = None,
|
||||
eligen_entity_masks: list[Image.Image] = None,
|
||||
eligen_enable_on_negative: bool = False,
|
||||
# Tile
|
||||
tiled: bool = False,
|
||||
tile_size: int = 128,
|
||||
@@ -213,6 +218,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -322,6 +328,84 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
return {}
|
||||
|
||||
|
||||
class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
onload_model_names=("text_encoder")
|
||||
)
|
||||
|
||||
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
prompt = [prompt]
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
||||
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
|
||||
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def preprocess_masks(self, pipe, masks, height, width, dim):
|
||||
out_masks = []
|
||||
for mask in masks:
|
||||
mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
out_masks.append(mask)
|
||||
return out_masks
|
||||
|
||||
def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height):
|
||||
entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
|
||||
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
|
||||
prompt_embs, prompt_emb_masks = [], []
|
||||
for entity_prompt in entity_prompts:
|
||||
prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt)
|
||||
prompt_embs.append(prompt_emb_dict['prompt_emb'])
|
||||
prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask'])
|
||||
return prompt_embs, prompt_emb_masks, entity_masks
|
||||
|
||||
def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale):
|
||||
entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height)
|
||||
if enable_eligen_on_negative and cfg_scale != 1.0:
|
||||
entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi)
|
||||
entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi)
|
||||
entity_masks_nega = entity_masks_posi
|
||||
else:
|
||||
entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None
|
||||
eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask}
|
||||
eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask}
|
||||
return eligen_kwargs_posi, eligen_kwargs_nega
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None)
|
||||
if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
|
||||
eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
|
||||
eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
|
||||
eligen_enable_on_negative, inputs_shared["cfg_scale"])
|
||||
inputs_posi.update(eligen_kwargs_posi)
|
||||
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
|
||||
inputs_nega.update(eligen_kwargs_nega)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
@@ -331,6 +415,9 @@ def model_fn_qwen_image(
|
||||
prompt_emb_mask=None,
|
||||
height=None,
|
||||
width=None,
|
||||
entity_prompt_emb=None,
|
||||
entity_prompt_emb_mask=None,
|
||||
entity_masks=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs
|
||||
@@ -342,9 +429,17 @@ def model_fn_qwen_image(
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
|
||||
image = dit.img_in(image)
|
||||
text = dit.txt_in(dit.txt_norm(prompt_emb))
|
||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
|
||||
if entity_prompt_emb is not None:
|
||||
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
|
||||
latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask,
|
||||
entity_masks, height, width, image, img_shapes,
|
||||
)
|
||||
else:
|
||||
text = dit.txt_in(dit.txt_norm(prompt_emb))
|
||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
attention_mask = None
|
||||
|
||||
for block in dit.transformer_blocks:
|
||||
text, image = gradient_checkpoint_forward(
|
||||
@@ -355,6 +450,7 @@ def model_fn_qwen_image(
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
image = dit.norm_out(image, conditioning)
|
||||
|
||||
89
examples/qwen_image/model_inference/Qwen-Image-EliGen.py
Normal file
89
examples/qwen_image/model_inference/Qwen-Image-EliGen.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from modelscope import dataset_snapshot_download
|
||||
import random
|
||||
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
example_id = 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a long dress, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea\n"
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
|
||||
entity_prompts = ["cliff", "sea", "red moon", "sailing boat", "a seated beautiful woman wearing red dress", "yellow long dress"]
|
||||
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
|
||||
for seed in range(20):
|
||||
image = pipe(global_prompt, seed=seed, num_inference_steps=40, eligen_entity_prompts=entity_prompts, eligen_entity_masks=masks, cfg_scale=4.0, height=1024, width=1024)
|
||||
image.save(f"workdirs/qwen_image/eligen_{seed}.jpg")
|
||||
|
||||
visualize_masks(image, masks, entity_prompts, f"workdirs/qwen_image/eligen_{seed}_mask.png")
|
||||
|
||||
image1 = pipe(global_prompt, seed=seed, num_inference_steps=40, height=1024, width=1024, cfg_scale=4.0)
|
||||
image1.save(f"workdirs/qwen_image/qwenimage_{seed}.jpg")
|
||||
19
examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh
Normal file
19
examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_eligen.json \
|
||||
--data_file_keys "image,eligen_entity_masks" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-EliGen_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--align_to_opensource_format \
|
||||
--extra_inputs "eligen_entity_masks,eligen_entity_prompts" \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
@@ -0,0 +1,29 @@
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Qwen-Image_lora/epoch-4.safetensors")
|
||||
|
||||
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
|
||||
image = pipe(global_prompt,
|
||||
seed=0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks)
|
||||
image.save("image.jpg")
|
||||
Reference in New Issue
Block a user