mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
Merge pull request #791 from mi804/qwen-image-longprompt
fix long prompt for qwen-image
This commit is contained in:
@@ -90,8 +90,39 @@ class QwenEmbedRope(nn.Module):
|
|||||||
)
|
)
|
||||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||||
return freqs
|
return freqs
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
|
||||||
|
if isinstance(video_fhw, list):
|
||||||
|
video_fhw = video_fhw[0]
|
||||||
|
_, height, width = video_fhw
|
||||||
|
if self.scale_rope:
|
||||||
|
max_vid_index = max(height // 2, width // 2)
|
||||||
|
else:
|
||||||
|
max_vid_index = max(height, width)
|
||||||
|
required_len = max_vid_index + max(txt_seq_lens)
|
||||||
|
cur_max_len = self.pos_freqs.shape[0]
|
||||||
|
if required_len <= cur_max_len:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_max_len = math.ceil(required_len / 512) * 512
|
||||||
|
pos_index = torch.arange(new_max_len)
|
||||||
|
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
|
||||||
|
self.pos_freqs = torch.cat([
|
||||||
|
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||||
|
], dim=1)
|
||||||
|
self.neg_freqs = torch.cat([
|
||||||
|
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||||
|
], dim=1)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def forward(self, video_fhw, txt_seq_lens, device):
|
def forward(self, video_fhw, txt_seq_lens, device):
|
||||||
|
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
|
||||||
if self.pos_freqs.device != device:
|
if self.pos_freqs.device != device:
|
||||||
self.pos_freqs = self.pos_freqs.to(device)
|
self.pos_freqs = self.pos_freqs.to(device)
|
||||||
self.neg_freqs = self.neg_freqs.to(device)
|
self.neg_freqs = self.neg_freqs.to(device)
|
||||||
|
|||||||
@@ -372,7 +372,9 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
|||||||
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
drop_idx = 34
|
drop_idx = 34
|
||||||
txt = [template.format(e) for e in prompt]
|
txt = [template.format(e) for e in prompt]
|
||||||
txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
txt_tokens = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
|
||||||
|
if txt_tokens.input_ids.shape[1] >= 1024:
|
||||||
|
print(f"Warning!!! QwenImage model was trained on prompts up to 512 tokens. Current prompt requires {txt_tokens['input_ids'].shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
|
||||||
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
|
hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
|
||||||
|
|
||||||
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||||
|
|||||||
Reference in New Issue
Block a user