diff --git a/diffsynth/models/qwen_image_controlnet.py b/diffsynth/models/qwen_image_controlnet.py index 67e6442..7c18011 100644 --- a/diffsynth/models/qwen_image_controlnet.py +++ b/diffsynth/models/qwen_image_controlnet.py @@ -54,8 +54,8 @@ class QwenImageControlNet(torch.nn.Module): img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)] txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() - image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2) - controlnet_conditioning = rearrange(controlnet_conditioning, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2) + image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + controlnet_conditioning = rearrange(controlnet_conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) image = torch.concat([image, controlnet_conditioning], dim=-1) image = self.img_in(image)