mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:32:27 +00:00
support qwen-image-edit
This commit is contained in:
@@ -63,8 +63,8 @@ class QwenEmbedRope(nn.Module):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(1024)
|
||||
neg_index = torch.arange(1024).flip(0) * -1 - 1
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat([
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
|
||||
@@ -52,7 +52,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
from transformers import Qwen2Tokenizer
|
||||
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
|
||||
self.text_encoder: QwenImageTextEncoder = None
|
||||
@@ -60,6 +60,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
self.vae: QwenImageVAE = None
|
||||
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
|
||||
self.tokenizer: Qwen2Tokenizer = None
|
||||
self.processor: Qwen2VLProcessor = None
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
self.in_iteration_models = ("dit", "blockwise_controlnet")
|
||||
self.units = [
|
||||
@@ -69,6 +70,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
QwenImageUnit_PromptEmbedder(),
|
||||
QwenImageUnit_EntityControl(),
|
||||
QwenImageUnit_BlockwiseControlNet(),
|
||||
QwenImageUnit_EditImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_qwen_image
|
||||
|
||||
@@ -218,6 +220,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
processor_config: ModelConfig = None,
|
||||
):
|
||||
# Download and load models
|
||||
model_manager = ModelManager()
|
||||
@@ -239,6 +242,10 @@ class QwenImagePipeline(BasePipeline):
|
||||
tokenizer_config.download_if_necessary()
|
||||
from transformers import Qwen2Tokenizer
|
||||
pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
|
||||
if processor_config is not None:
|
||||
processor_config.download_if_necessary()
|
||||
from transformers import Qwen2VLProcessor
|
||||
pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -266,6 +273,8 @@ class QwenImagePipeline(BasePipeline):
|
||||
eligen_entity_prompts: list[str] = None,
|
||||
eligen_entity_masks: list[Image.Image] = None,
|
||||
eligen_enable_on_negative: bool = False,
|
||||
# Edit Image
|
||||
edit_image: Image.Image = None,
|
||||
# FP8
|
||||
enable_fp8_attention: bool = False,
|
||||
# Tile
|
||||
@@ -295,6 +304,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
||||
"edit_image": edit_image,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -366,13 +376,13 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": latents, "input_latents": None}
|
||||
|
||||
|
||||
|
||||
class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
input_params=("edit_image",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
@@ -383,18 +393,35 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt) -> dict:
|
||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
prompt = [prompt]
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
# If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit
|
||||
if edit_image is None:
|
||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 34
|
||||
else:
|
||||
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 64
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
if txt_tokens.input_ids.shape[1] >= 1024:
|
||||
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {txt_tokens['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
|
||||
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
|
||||
# Qwen-Image-Edit model
|
||||
if pipe.processor is not None:
|
||||
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
||||
# Qwen-Image model
|
||||
elif pipe.tokenizer is not None:
|
||||
model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||
if model_inputs.input_ids.shape[1] >= 1024:
|
||||
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
|
||||
else:
|
||||
assert False, "QwenImagePipeline requires either tokenizer or processor to be loaded."
|
||||
|
||||
if 'pixel_values' in model_inputs:
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
||||
else:
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]
|
||||
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
@@ -528,6 +555,23 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
|
||||
return {"blockwise_controlnet_conditioning": conditionings}
|
||||
|
||||
|
||||
class QwenImageUnit_EditImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, edit_image, height, width, tiled, tile_size, tile_stride):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
edit_image = edit_image.resize((width, height))
|
||||
pipe.load_models_to_device(['vae'])
|
||||
edit_image = pipe.preprocess_image(edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return {"edit_latents": edit_latents}
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None,
|
||||
@@ -544,6 +588,7 @@ def model_fn_qwen_image(
|
||||
entity_prompt_emb=None,
|
||||
entity_prompt_emb_mask=None,
|
||||
entity_masks=None,
|
||||
edit_latents=None,
|
||||
enable_fp8_attention=False,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
@@ -554,7 +599,13 @@ def model_fn_qwen_image(
|
||||
timestep = timestep / 1000
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
|
||||
|
||||
if edit_latents is not None:
|
||||
img_shapes[0] = (img_shapes[0][0] + edit_latents.shape[0], img_shapes[0][1], img_shapes[0][2])
|
||||
edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image_seq_len = image.shape[1]
|
||||
image = torch.cat([image, edit_image], dim=1)
|
||||
|
||||
image = dit.img_in(image)
|
||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||
|
||||
@@ -593,6 +644,8 @@ def model_fn_qwen_image(
|
||||
|
||||
image = dit.norm_out(image, conditioning)
|
||||
image = dit.proj_out(image)
|
||||
if edit_latents is not None:
|
||||
image = image[:, :image_seq_len]
|
||||
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
return latents
|
||||
|
||||
@@ -552,4 +552,6 @@ def qwen_image_parser():
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
|
||||
parser.add_argument("--edit_model", default=False, action="store_true", help="Whether to use Qwen-Image-Edit. If True, the model will be used for image editing.")
|
||||
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
|
||||
return parser
|
||||
|
||||
Reference in New Issue
Block a user