Flux lora (#205)

This commit is contained in:
Zhongjie Duan
2024-09-12 16:49:30 +08:00
committed by GitHub
parent a8cb4a21d1
commit c21ed1e478
2 changed files with 3 additions and 3 deletions

View File

@@ -9,6 +9,7 @@ from .dancer import lets_dance_xl
from typing import List
import torch
from tqdm import tqdm
from einops import repeat
@@ -103,7 +104,8 @@ class SDXLImagePipeline(BasePipeline):
def prepare_extra_input(self, latents=None):
height, width = latents.shape[2] * 8, latents.shape[3] * 8
return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)}
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0])
return {"add_time_id": add_time_id}
@torch.no_grad()

View File

@@ -49,8 +49,6 @@ class FluxPrompter(BasePrompter):
truncation=True,
).input_ids.to(device)
prompt_emb = text_encoder(input_ids)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb