mirror of https://github.com/modelscope/DiffSynth-Studio.git
support SD3 LoRA
@@ -567,7 +567,7 @@ class ModelManager:
         if component == "sd3_text_encoder_3":
             if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
                 continue
-        elif component == "sd3_text_encoder_1":
+        if component == "sd3_text_encoder_1":
             # Add additional token embeddings to text encoder
             token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
             for keyword in self.textual_inversion_dict:
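
The surrounding context merges textual-inversion embeddings into the CLIP token-embedding table. A minimal sketch of that general pattern, with illustrative shapes and names (base_weight and ti_vectors are stand-ins, not the ModelManager API):

    import torch

    # Hypothetical shapes: CLIP-L vocabulary 49408 tokens, hidden size 768.
    base_weight = torch.randn(49408, 768)             # original token_embedding.weight
    ti_vectors = {"<my-token>": torch.randn(4, 768)}  # learned vectors per keyword

    token_embeddings = [base_weight]
    for keyword, vectors in ti_vectors.items():
        token_embeddings.append(vectors)              # each keyword contributes rows
    # New rows correspond to newly registered token ids.
    embedding_table = torch.cat(token_embeddings, dim=0)
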
@@ -199,7 +199,7 @@ class SD3DiT(torch.nn.Module):
         )
         return hidden_states

-    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64):
+    def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
         if tiled:
             return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
         conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
@@ -207,8 +207,22 @@ class SD3DiT(torch.nn.Module):

         height, width = hidden_states.shape[-2:]
         hidden_states = self.pos_embedder(hidden_states)

+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
+
         for block in self.blocks:
-            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
+            if self.training and use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)

         hidden_states = self.norm_out(hidden_states, conditioning)
         hidden_states = self.proj_out(hidden_states)
         hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
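
The new use_gradient_checkpointing path trades compute for memory: activations inside each block are discarded on the forward pass and recomputed during backward, which is what makes LoRA training fit on smaller GPUs. A self-contained sketch of the same pattern (the Linear stack is a stand-in for the DiT blocks):

    import torch
    import torch.utils.checkpoint

    blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    x = torch.randn(2, 64, requires_grad=True)
    for block in blocks:
        # Only the block inputs are saved; intermediate activations are
        # recomputed during backward, lowering peak memory.
        x = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block), x, use_reentrant=False,
        )
    x.sum().backward()
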
@@ -69,7 +69,7 @@ class SD3Prompter(Prompter):

         # T5
         if text_encoder_3 is None:
-            prompt_emb_3 = torch.zeros((1, 256, 4096), dtype=prompt_emb_1.dtype, device=device)
+            prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
         else:
             prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
             prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
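
The old (1, 256, 4096) placeholder only worked for batch size 1; with a larger prompt batch it cannot be concatenated with the CLIP embeddings. A small repro of the shape issue (the 4096-channel CLIP shape here is illustrative; in SD3 the CLIP embeddings are padded to the T5 width before concatenation):

    import torch

    prompt_emb_1 = torch.randn(2, 77, 4096)               # batch of 2 prompts
    bad  = torch.zeros(1, 256, 4096)                      # fixed batch of 1
    good = torch.zeros(prompt_emb_1.shape[0], 256, 4096)  # batch follows the input

    torch.cat([prompt_emb_1, good], dim=1)   # ok: (2, 333, 4096)
    # torch.cat([prompt_emb_1, bad], dim=1)  # RuntimeError: batch sizes differ
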
@@ -124,6 +124,13 @@ but make sure there is a correlation between the input and output.\n\
         return prompt

     def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
+        if isinstance(prompt, list):
+            prompt = [self.process_prompt(prompt_, positive=positive, require_pure_prompt=require_pure_prompt) for prompt_ in prompt]
+            if require_pure_prompt:
+                prompt, pure_prompt = [i[0] for i in prompt], [i[1] for i in prompt]
+                return prompt, pure_prompt
+            else:
+                return prompt
         prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
         if positive and self.translator is not None:
             prompt = self.translator(prompt)
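
The added branch lets process_prompt accept a list of prompts (e.g. for batched generation), recursing element-wise and unzipping the (prompt, pure_prompt) pairs back into two parallel lists. A standalone sketch of that recursion pattern, where the string cleanup is a stand-in for the textual-inversion handling:

    def process(prompt, require_pair=False):
        # List input: map recursively, then unzip pairs into parallel lists.
        if isinstance(prompt, list):
            out = [process(p, require_pair=require_pair) for p in prompt]
            if require_pair:
                return [i[0] for i in out], [i[1] for i in out]
            return out
        cleaned = prompt.strip()
        return (cleaned, cleaned.lower()) if require_pair else cleaned

    print(process(["  A Cat ", " A Dog "], require_pair=True))
    # (['A Cat', 'A Dog'], ['a cat', 'a dog'])
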
@@ -40,3 +40,8 @@ class FlowMatchScheduler():
         sigma = self.sigmas[timestep_id]
         sample = (1 - sigma) * original_samples + sigma * noise
         return sample
+
+
+    def training_target(self, sample, noise, timestep):
+        target = noise - sample
+        return target
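
training_target supplies the regression target for flow-matching training: since add_noise builds x_t = (1 - sigma) * x_0 + sigma * noise, the velocity the model should predict is noise - x_0. A minimal training step under that convention (model is a stand-in network):

    import torch

    model = torch.nn.Linear(16, 16)          # stand-in for the DiT
    x0 = torch.randn(8, 16)                  # clean samples
    noise = torch.randn_like(x0)
    sigma = torch.rand(8, 1)                 # per-sample noise level

    x_t = (1 - sigma) * x0 + sigma * noise   # add_noise
    target = noise - x0                      # training_target
    loss = torch.nn.functional.mse_loss(model(x_t), target)
    loss.backward()
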