Support qwen-image-edit low-resolution (lowres) RoPE interpolation fix

This commit is contained in:
mi804
2025-08-19 20:15:36 +08:00
parent 3a9621f6da
commit 838b8109b1
8 changed files with 133 additions and 6 deletions

View File

@@ -166,6 +166,66 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
def forward_sampling(self, video_fhw, txt_seq_lens, device):
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache:
frame_0, height_0, width_0 = video_fhw[0]
rope_key_0 = f"0_{height_0}_{width_0}"
spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)
h_indices = torch.linspace(0, height_0 - 1, height).long()
w_indices = torch.linspace(0, width_0 - 1, width).long()
h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')
sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame
seq_lens = frame * height * width
self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone()
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
self.rope_cache[rope_key] = freqs.clone()
vid_freqs.append(self.rope_cache[rope_key].contiguous())
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
class QwenFeedForward(nn.Module):
def __init__(
self,

View File

@@ -68,10 +68,10 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_NoiseInitializer(),
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_Inpaint(),
QwenImageUnit_EditImageEmbedder(),
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
QwenImageUnit_BlockwiseControlNet(),
QwenImageUnit_EditImageEmbedder(),
]
self.model_fn = model_fn_qwen_image
@@ -280,6 +280,7 @@ class QwenImagePipeline(BasePipeline):
eligen_enable_on_negative: bool = False,
# Edit Image
edit_image: Image.Image = None,
edit_rope_interpolation: bool = False,
# FP8
enable_fp8_attention: bool = False,
# Tile
@@ -310,7 +311,7 @@ class QwenImagePipeline(BasePipeline):
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
"edit_image": edit_image,
"edit_image": edit_image, "edit_rope_interpolation": edit_rope_interpolation,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -583,11 +584,11 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
class QwenImageUnit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "height", "width", "tiled", "tile_size", "tile_stride"),
input_params=("edit_image", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, edit_image, height, width, tiled, tile_size, tile_stride):
def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride):
if edit_image is None:
return {}
pipe.load_models_to_device(['vae'])
@@ -616,6 +617,7 @@ def model_fn_qwen_image(
enable_fp8_attention=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
edit_rope_interpolation=False,
**kwargs
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
@@ -640,7 +642,10 @@ def model_fn_qwen_image(
)
else:
text = dit.txt_in(dit.txt_norm(prompt_emb))
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
if edit_rope_interpolation:
image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device)
else:
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
attention_mask = None
if blockwise_controlnet_conditioning is not None: