mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:39:43 +00:00
@@ -524,37 +524,63 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
def calculate_dimensions(self, target_area, ratio):
|
||||
import math
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
return width, height
|
||||
|
||||
def resize_image(self, image, target_area=384*384):
|
||||
width, height = self.calculate_dimensions(target_area, image.size[0] / image.size[1])
|
||||
return image.resize((width, height))
|
||||
|
||||
def encode_prompt(self, pipe: QwenImagePipeline, prompt):
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
txt = [template.format(e) for e in prompt]
|
||||
model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
if model_inputs.input_ids.shape[1] >= 1024:
|
||||
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
return split_hidden_states
|
||||
|
||||
def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):
|
||||
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 64
|
||||
txt = [template.format(e) for e in prompt]
|
||||
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
return split_hidden_states
|
||||
|
||||
def encode_prompt_edit_multi(self, pipe: QwenImagePipeline, prompt, edit_image):
|
||||
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 64
|
||||
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
|
||||
base_img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(edit_image))])
|
||||
txt = [template.format(base_img_prompt + e) for e in prompt]
|
||||
edit_image = [self.resize_image(image) for image in edit_image]
|
||||
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
return split_hidden_states
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
prompt = [prompt]
|
||||
# If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit
|
||||
if edit_image is None:
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
split_hidden_states = self.encode_prompt(pipe, prompt)
|
||||
elif isinstance(edit_image, Image.Image):
|
||||
split_hidden_states = self.encode_prompt_edit(pipe, prompt, edit_image)
|
||||
else:
|
||||
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 64
|
||||
txt = [template.format(e) for e in prompt]
|
||||
|
||||
# Qwen-Image-Edit model
|
||||
if pipe.processor is not None:
|
||||
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
||||
# Qwen-Image model
|
||||
elif pipe.tokenizer is not None:
|
||||
model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
if model_inputs.input_ids.shape[1] >= 1024:
|
||||
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
|
||||
else:
|
||||
assert False, "QwenImagePipeline requires either tokenizer or processor to be loaded."
|
||||
|
||||
if 'pixel_values' in model_inputs:
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
||||
else:
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
split_hidden_states = self.encode_prompt_edit_multi(pipe, prompt, edit_image)
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
||||
@@ -712,12 +738,23 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
|
||||
def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
|
||||
pipe.load_models_to_device(['vae'])
|
||||
edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if isinstance(edit_image, Image.Image):
|
||||
resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
|
||||
edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
else:
|
||||
resized_edit_image, edit_latents = [], []
|
||||
for image in edit_image:
|
||||
if edit_image_auto_resize:
|
||||
image = self.edit_image_auto_resize(image)
|
||||
resized_edit_image.append(image)
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
edit_latents.append(latents)
|
||||
return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
|
||||
|
||||
|
||||
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -770,9 +807,10 @@ def model_fn_qwen_image(
|
||||
context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2)
|
||||
image = torch.cat([image, context_image], dim=1)
|
||||
if edit_latents is not None:
|
||||
img_shapes += [(edit_latents.shape[0], edit_latents.shape[2]//2, edit_latents.shape[3]//2)]
|
||||
edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=edit_latents.shape[2]//2, W=edit_latents.shape[3]//2, P=2, Q=2)
|
||||
image = torch.cat([image, edit_image], dim=1)
|
||||
edit_latents_list = edit_latents if isinstance(edit_latents, list) else [edit_latents]
|
||||
img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
|
||||
edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
|
||||
image = torch.cat([image] + edit_image, dim=1)
|
||||
|
||||
image = dit.img_in(image)
|
||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||
|
||||
Reference in New Issue
Block a user