From c21ed1e478975c1c80dc92c8e4a6e13e6a77719b Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:49:30 +0800 Subject: [PATCH] Flux lora (#205) --- diffsynth/pipelines/sdxl_image.py | 4 +++- diffsynth/prompters/flux_prompter.py | 2 -- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/diffsynth/pipelines/sdxl_image.py b/diffsynth/pipelines/sdxl_image.py index 81f6b40..09b7dd4 100644 --- a/diffsynth/pipelines/sdxl_image.py +++ b/diffsynth/pipelines/sdxl_image.py @@ -9,6 +9,7 @@ from .dancer import lets_dance_xl from typing import List import torch from tqdm import tqdm +from einops import repeat @@ -103,7 +104,8 @@ class SDXLImagePipeline(BasePipeline): def prepare_extra_input(self, latents=None): height, width = latents.shape[2] * 8, latents.shape[3] * 8 - return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)} + add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0]) + return {"add_time_id": add_time_id} @torch.no_grad() diff --git a/diffsynth/prompters/flux_prompter.py b/diffsynth/prompters/flux_prompter.py index 5b40ecb..656f43b 100644 --- a/diffsynth/prompters/flux_prompter.py +++ b/diffsynth/prompters/flux_prompter.py @@ -49,8 +49,6 @@ class FluxPrompter(BasePrompter): truncation=True, ).input_ids.to(device) prompt_emb = text_encoder(input_ids) - prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) - return prompt_emb