mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
support qwen-image-edit lowres fix
This commit is contained in:
@@ -166,6 +166,66 @@ class QwenEmbedRope(nn.Module):
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
def forward_sampling(self, video_fhw, txt_seq_lens, device):
    """Build rotary frequency tables, resampling later entries onto the first grid.

    For each ``(frame, height, width)`` entry of ``video_fhw`` a table of
    ``frame * height * width`` rows is produced and all tables are
    concatenated along dim 0.

    * The first entry, and any later entry whose ``(height, width)`` equals
      the first entry's, gets the regular table: frame, height and width
      frequency slices concatenated along the last dim.
    * A later entry with a different ``(height, width)`` takes its height and
      width frequencies from the first entry's table, sampled at
      ``torch.linspace(0, size_0 - 1, size).long()`` positions along each
      axis; only its frame frequencies come from its own index.

    Tables are memoised in ``self.rope_cache``.  Resampled tables use a key
    that also contains the reference size, so they can never be confused with
    a regular table for the same ``idx``/``height``/``width``.

    Args:
        video_fhw: Sequence of ``(frame, height, width)`` triples.
        txt_seq_lens: Sequence of text lengths; only its maximum is used.
        device: Device that ``self.pos_freqs`` / ``self.neg_freqs`` are moved
            to when they currently live elsewhere.

    Returns:
        Tuple ``(vid_freqs, txt_freqs)``: the concatenated per-entry tables
        and the rows ``max_vid_index : max_vid_index + max(txt_seq_lens)`` of
        ``self.pos_freqs``.
    """
    self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
    if self.pos_freqs.device != device:
        self.pos_freqs = self.pos_freqs.to(device)
        self.neg_freqs = self.neg_freqs.to(device)

    # The per-axis split does not change inside the loop, so do it once.
    split_sizes = [x // 2 for x in self.axes_dim]
    freqs_pos = self.pos_freqs.split(split_sizes, dim=1)
    freqs_neg = self.neg_freqs.split(split_sizes, dim=1)

    vid_freqs = []
    max_vid_index = 0
    for idx, fhw in enumerate(video_fhw):
        frame, height, width = fhw
        frame_0, height_0, width_0 = video_fhw[0]

        # Decide by comparing with the first entry's size rather than by
        # probing the cache, so the result does not depend on what earlier
        # calls happened to leave in ``rope_cache``.
        resample = idx > 0 and (height, width) != (height_0, width_0)
        if resample:
            rope_key = f"{idx}_{height}_{width}_sampled_{height_0}_{width_0}"
        else:
            rope_key = f"{idx}_{height}_{width}"

        if rope_key not in self.rope_cache:
            seq_lens = frame * height * width
            freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)

            if resample:
                # The first entry was handled earlier in this same loop, so
                # its regular table is guaranteed to be cached.
                rope_key_0 = f"0_{height_0}_{width_0}"
                spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)
                h_indices = torch.linspace(0, height_0 - 1, height).long()
                w_indices = torch.linspace(0, width_0 - 1, width).long()
                h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')
                # Advanced indexing returns a fresh tensor, so the in-place
                # write below does not touch the cached reference table.
                sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]
                # Keep the sampled height/width part, replace the frame part
                # with this entry's own frame frequencies.
                sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame
                freqs = sampled_rope.reshape(seq_lens, -1)
            else:
                if self.scale_rope:
                    # Centered layout: trailing negative rows followed by
                    # leading positive rows along each spatial axis.
                    freqs_height = torch.cat(
                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
                    )
                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = torch.cat(
                        [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0
                    )
                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
                else:
                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)

            self.rope_cache[rope_key] = freqs.clone()

        vid_freqs.append(self.rope_cache[rope_key].contiguous())

        # Text rows start after the largest spatial index used by any entry.
        if self.scale_rope:
            max_vid_index = max(height // 2, width // 2, max_vid_index)
        else:
            max_vid_index = max(height, width, max_vid_index)

    max_len = max(txt_seq_lens)
    txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
    vid_freqs = torch.cat(vid_freqs, dim=0)

    return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
class QwenFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user