mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 10:48:11 +00:00
Flux lora (#205)
This commit is contained in:
@@ -9,6 +9,7 @@ from .dancer import lets_dance_xl
|
|||||||
from typing import List
|
from typing import List
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from einops import repeat
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -103,7 +104,8 @@ class SDXLImagePipeline(BasePipeline):
|
|||||||
|
|
||||||
def prepare_extra_input(self, latents=None):
|
def prepare_extra_input(self, latents=None):
|
||||||
height, width = latents.shape[2] * 8, latents.shape[3] * 8
|
height, width = latents.shape[2] * 8, latents.shape[3] * 8
|
||||||
return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)}
|
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0])
|
||||||
|
return {"add_time_id": add_time_id}
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@@ -49,8 +49,6 @@ class FluxPrompter(BasePrompter):
|
|||||||
truncation=True,
|
truncation=True,
|
||||||
).input_ids.to(device)
|
).input_ids.to(device)
|
||||||
prompt_emb = text_encoder(input_ids)
|
prompt_emb = text_encoder(input_ids)
|
||||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
|
||||||
|
|
||||||
return prompt_emb
|
return prompt_emb
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user