mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 00:58:11 +00:00
qwen-image dit original forward fix
This commit is contained in:
@@ -422,7 +422,7 @@ class QwenImageDiT(torch.nn.Module):
|
|||||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||||
|
|
||||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
|
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||||
image = self.img_in(image)
|
image = self.img_in(image)
|
||||||
text = self.txt_in(self.txt_norm(prompt_emb))
|
text = self.txt_in(self.txt_norm(prompt_emb))
|
||||||
|
|
||||||
@@ -441,7 +441,7 @@ class QwenImageDiT(torch.nn.Module):
|
|||||||
image = self.norm_out(image, conditioning)
|
image = self.norm_out(image, conditioning)
|
||||||
image = self.proj_out(image)
|
image = self.proj_out(image)
|
||||||
|
|
||||||
latents = rearrange(image, "B (H W) (P Q C) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user