mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
Merge pull request #207 from modelscope/flux-schnell
support flux-schnell
This commit is contained in:
@@ -76,6 +76,7 @@ model_loader_configs = [
|
||||
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
||||
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
||||
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
|
||||
@@ -281,11 +281,11 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
@@ -362,9 +362,9 @@ class FluxDiT(torch.nn.Module):
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
|
||||
+ self.guidance_embedder(guidance, hidden_states.dtype)\
|
||||
+ self.pooled_text_embedder(pooled_prompt_emb)
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
if self.guidance_embedder is not None:
|
||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
@@ -589,5 +589,7 @@ class FluxDiTStateDictConverter:
|
||||
state_dict_[rename] = param
|
||||
else:
|
||||
pass
|
||||
return state_dict_
|
||||
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
return state_dict_, {"disable_guidance_embedder": True}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
Reference in New Issue
Block a user