From f7f5c075702030760a6cf4ce92db1c649d1820f4 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 13 Aug 2025 17:23:00 +0800 Subject: [PATCH] fix long prompt for qwen-image --- diffsynth/models/qwen_image_dit.py | 33 +++++++++++++++++++++++++++++- diffsynth/pipelines/qwen_image.py | 4 +++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py index 7841d50..b60b0e4 100644 --- a/diffsynth/models/qwen_image_dit.py +++ b/diffsynth/models/qwen_image_dit.py @@ -90,8 +90,39 @@ class QwenEmbedRope(nn.Module): ) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs - + + + def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens): + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + _, height, width = video_fhw + if self.scale_rope: + max_vid_index = max(height // 2, width // 2) + else: + max_vid_index = max(height, width) + required_len = max_vid_index + max(txt_seq_lens) + cur_max_len = self.pos_freqs.shape[0] + if required_len <= cur_max_len: + return + + new_max_len = math.ceil(required_len / 512) * 512 + pos_index = torch.arange(new_max_len) + neg_index = torch.arange(new_max_len).flip(0) * -1 - 1 + self.pos_freqs = torch.cat([ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], dim=1) + self.neg_freqs = torch.cat([ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], dim=1) + return + + def forward(self, video_fhw, txt_seq_lens, device): + self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) if self.pos_freqs.device != device: self.pos_freqs = self.pos_freqs.to(device) self.neg_freqs = self.neg_freqs.to(device) diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index c475415..ef333f4 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -372,7 +372,9 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit): template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" drop_idx = 34 txt = [template.format(e) for e in prompt] - txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + txt_tokens = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device) + if txt_tokens.input_ids.shape[1] >= 1024: + print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {txt_tokens['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.") hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1] split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)