support qwen-image-edit

This commit is contained in:
mi804
2025-08-18 16:07:45 +08:00
parent 7dc49bd036
commit 9f6922bba9
14 changed files with 212 additions and 19 deletions

View File

@@ -63,8 +63,8 @@ class QwenEmbedRope(nn.Module):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(1024)
neg_index = torch.arange(1024).flip(0) * -1 - 1
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),

View File

@@ -52,7 +52,7 @@ class QwenImagePipeline(BasePipeline):
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
from transformers import Qwen2Tokenizer
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
self.text_encoder: QwenImageTextEncoder = None
@@ -60,6 +60,7 @@ class QwenImagePipeline(BasePipeline):
self.vae: QwenImageVAE = None
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
self.tokenizer: Qwen2Tokenizer = None
self.processor: Qwen2VLProcessor = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit", "blockwise_controlnet")
self.units = [
@@ -69,6 +70,7 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
QwenImageUnit_BlockwiseControlNet(),
QwenImageUnit_EditImageEmbedder(),
]
self.model_fn = model_fn_qwen_image
@@ -218,6 +220,7 @@ class QwenImagePipeline(BasePipeline):
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
processor_config: ModelConfig = None,
):
# Download and load models
model_manager = ModelManager()
@@ -239,6 +242,10 @@ class QwenImagePipeline(BasePipeline):
tokenizer_config.download_if_necessary()
from transformers import Qwen2Tokenizer
pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
if processor_config is not None:
processor_config.download_if_necessary()
from transformers import Qwen2VLProcessor
pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
return pipe
@@ -266,6 +273,8 @@ class QwenImagePipeline(BasePipeline):
eligen_entity_prompts: list[str] = None,
eligen_entity_masks: list[Image.Image] = None,
eligen_enable_on_negative: bool = False,
# Edit Image
edit_image: Image.Image = None,
# FP8
enable_fp8_attention: bool = False,
# Tile
@@ -295,6 +304,7 @@ class QwenImagePipeline(BasePipeline):
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
"edit_image": edit_image,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -366,13 +376,13 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": None}
class QwenImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
input_params=("edit_image",),
onload_model_names=("text_encoder",)
)
@@ -383,18 +393,35 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
return split_result
def process(self, pipe: QwenImagePipeline, prompt) -> dict:
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
if pipe.text_encoder is not None:
prompt = [prompt]
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
drop_idx = 34
# If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit
if edit_image is None:
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
drop_idx = 34
else:
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
drop_idx = 64
txt = [template.format(e) for e in prompt]
txt_tokens = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
if txt_tokens.input_ids.shape[1] >= 1024:
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {txt_tokens['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
# Qwen-Image-Edit model
if pipe.processor is not None:
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
# Qwen-Image model
elif pipe.tokenizer is not None:
model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
if model_inputs.input_ids.shape[1] >= 1024:
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {model_inputs['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
else:
assert False, "QwenImagePipeline requires either tokenizer or processor to be loaded."
if 'pixel_values' in model_inputs:
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
else:
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
@@ -528,6 +555,23 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
return {"blockwise_controlnet_conditioning": conditionings}
class QwenImageUnit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, edit_image, height, width, tiled, tile_size, tile_stride):
if edit_image is None:
return {}
edit_image = edit_image.resize((width, height))
pipe.load_models_to_device(['vae'])
edit_image = pipe.preprocess_image(edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return {"edit_latents": edit_latents}
def model_fn_qwen_image(
dit: QwenImageDiT = None,
blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None,
@@ -544,6 +588,7 @@ def model_fn_qwen_image(
entity_prompt_emb=None,
entity_prompt_emb_mask=None,
entity_masks=None,
edit_latents=None,
enable_fp8_attention=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
@@ -554,7 +599,13 @@ def model_fn_qwen_image(
timestep = timestep / 1000
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
if edit_latents is not None:
img_shapes[0] = (img_shapes[0][0] + edit_latents.shape[0], img_shapes[0][1], img_shapes[0][2])
edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image_seq_len = image.shape[1]
image = torch.cat([image, edit_image], dim=1)
image = dit.img_in(image)
conditioning = dit.time_text_embed(timestep, image.dtype)
@@ -593,6 +644,8 @@ def model_fn_qwen_image(
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)
if edit_latents is not None:
image = image[:, :image_seq_len]
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
return latents

View File

@@ -552,4 +552,6 @@ def qwen_image_parser():
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--edit_model", default=False, action="store_true", help="Whether to use Qwen-Image-Edit. If True, the model will be used for image editing.")
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
return parser