diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index b4ec524..b5057d6 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -90,7 +90,7 @@ model_loader_configs = [
     (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
     (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
     (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
-    # (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai")
+    (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
@@ -590,6 +590,18 @@ preset_models_on_modelscope = {
         ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
         ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
     ],
+    "StableDiffusion3.5-medium": [
+        ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
+        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+    ],
+    "StableDiffusion3.5-large-turbo": [
+        ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
+        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+    ],
 }
 Preset_model_id: TypeAlias = Literal[
     "HunyuanDiT",
@@ -643,4 +655,5 @@ Preset_model_id: TypeAlias = Literal[
     "Annotators:Normal",
     "Annotators:Openpose",
     "StableDiffusion3.5-large",
+    "StableDiffusion3.5-medium",
 ]
diff --git a/diffsynth/models/sd3_dit.py b/diffsynth/models/sd3_dit.py
index 5b44068..6168088 100644
--- a/diffsynth/models/sd3_dit.py
+++ b/diffsynth/models/sd3_dit.py
@@ -62,10 +62,11 @@ class TimestepEmbeddings(torch.nn.Module):
 
 
 class AdaLayerNorm(torch.nn.Module):
-    def __init__(self, dim, single=False):
+    def __init__(self, dim, single=False, dual=False):
         super().__init__()
         self.single = single
-        self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
+        self.dual = dual
+        self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
         self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
 
     def forward(self, x, emb):
@@ -74,6 +75,12 @@ class AdaLayerNorm(torch.nn.Module):
             scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
             x = self.norm(x) * (1 + scale) + shift
             return x
+        elif self.dual:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
+            norm_x = self.norm(x)
+            x = norm_x * (1 + scale_msa) + shift_msa
+            norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
+            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
         else:
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
             x = self.norm(x) * (1 + scale_msa) + shift_msa
@@ -138,16 +145,58 @@ class JointAttention(torch.nn.Module):
         else:
             hidden_states_b = self.b_to_out(hidden_states_b)
         return hidden_states_a, hidden_states_b
+
+
+class SingleAttention(torch.nn.Module):
+    def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
 
-class JointTransformerBlock(torch.nn.Module):
+        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
+        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
+
+        if use_rms_norm:
+            self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
+            self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
+        else:
+            self.norm_q_a = None
+            self.norm_k_a = None
+
+
+    def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
+        batch_size = hidden_states.shape[0]
+        qkv = to_qkv(hidden_states)
+        qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+        q, k, v = qkv.chunk(3, dim=1)
+        if norm_q is not None:
+            q = norm_q(q)
+        if norm_k is not None:
+            k = norm_k(k)
+        return q, k, v
+
+
+    def forward(self, hidden_states_a):
+        batch_size = hidden_states_a.shape[0]
+        q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.to(q.dtype)
+        hidden_states = self.a_to_out(hidden_states)
+        return hidden_states
+
+
+
+class DualTransformerBlock(torch.nn.Module):
     def __init__(self, dim, num_attention_heads, use_rms_norm=False):
         super().__init__()
-        self.norm1_a = AdaLayerNorm(dim)
+        self.norm1_a = AdaLayerNorm(dim, dual=True)
         self.norm1_b = AdaLayerNorm(dim)
 
         self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
+        self.attn2 = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
 
         self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff_a = torch.nn.Sequential(
@@ -165,7 +214,7 @@ class JointTransformerBlock(torch.nn.Module):
 
 
     def forward(self, hidden_states_a, hidden_states_b, temb):
-        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
         norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
 
         # Attention
@@ -173,6 +222,58 @@ class JointTransformerBlock(torch.nn.Module):
 
         # Part A
         hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+        hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
+        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+        # Part B
+        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
+        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
+        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
+
+        return hidden_states_a, hidden_states_b
+
+
+
+class JointTransformerBlock(torch.nn.Module):
+    def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False):
+        super().__init__()
+        self.norm1_a = AdaLayerNorm(dim, dual=dual)
+        self.norm1_b = AdaLayerNorm(dim)
+
+        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
+        if dual:
+            self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
+
+        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff_a = torch.nn.Sequential(
+            torch.nn.Linear(dim, dim*4),
+            torch.nn.GELU(approximate="tanh"),
+            torch.nn.Linear(dim*4, dim)
+        )
+
+        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff_b = torch.nn.Sequential(
+            torch.nn.Linear(dim, dim*4),
+            torch.nn.GELU(approximate="tanh"),
+            torch.nn.Linear(dim*4, dim)
+        )
+
+
+    def forward(self, hidden_states_a, hidden_states_b, temb):
+        if self.norm1_a.dual:
+            norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
+        else:
+            norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
+
+        # Attention
+        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
+
+        # Part A
+        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+        if self.norm1_a.dual:
+            hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
         norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
         hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
 
@@ -218,13 +319,14 @@ class JointTransformerFinalBlock(torch.nn.Module):
 
 
 class SD3DiT(torch.nn.Module):
-    def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False):
+    def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192):
         super().__init__()
-        self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=192)
+        self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size)
         self.time_embedder = TimestepEmbeddings(256, embed_dim)
         self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, embed_dim), torch.nn.SiLU(), torch.nn.Linear(embed_dim, embed_dim))
         self.context_embedder = torch.nn.Linear(4096, embed_dim)
-        self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1)]
+        self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm, dual=True) for _ in range(num_dual_blocks)]
+                                          + [JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1-num_dual_blocks)]
                                           + [JointTransformerFinalBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm)])
         self.norm_out = AdaLayerNorm(embed_dim, single=True)
         self.proj_out = torch.nn.Linear(embed_dim, 64)
@@ -286,7 +388,17 @@ class SD3DiTStateDictConverter:
         while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
             num_layers -= 1
         use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict
-        return {"embed_dim": embed_dim, "num_layers": num_layers, "use_rms_norm": use_rms_norm}
+        num_dual_blocks = 0
+        while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict:
+            num_dual_blocks += 1
+        pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1]
+        return {
+            "embed_dim": embed_dim,
+            "num_layers": num_layers,
+            "use_rms_norm": use_rms_norm,
+            "num_dual_blocks": num_dual_blocks,
+            "pos_embed_max_size": pos_embed_max_size
+        }
 
     def from_diffusers(self, state_dict):
         rename_dict = {
@@ -402,13 +514,21 @@ class SD3DiTStateDictConverter:
                 f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight",
                 f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight",
                 f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight",
+
+                f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_q.weight": f"blocks.{i}.attn2.norm_q_a.weight",
+                f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_k.weight": f"blocks.{i}.attn2.norm_k_a.weight",
+                f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.weight": f"blocks.{i}.attn2.a_to_qkv.weight",
+                f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.bias": f"blocks.{i}.attn2.a_to_qkv.bias",
+                f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.weight": f"blocks.{i}.attn2.a_to_out.weight",
+                f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.bias": f"blocks.{i}.attn2.a_to_out.bias",
             })
         state_dict_ = {}
         for name in state_dict:
             if name in rename_dict:
                 param = state_dict[name]
                 if name == "model.diffusion_model.pos_embed":
-                    param = param.reshape((1, 192, 192, param.shape[-1]))
+                    pos_embed_max_size = int(param.shape[1] ** 0.5 + 0.4)
+                    param = param.reshape((1, pos_embed_max_size, pos_embed_max_size, param.shape[-1]))
                 if isinstance(rename_dict[name], str):
                     state_dict_[rename_dict[name]] = param
                 else:
diff --git a/diffsynth/pipelines/cog_video.py b/diffsynth/pipelines/cog_video.py
index 4b7f336..f42d295 100644
--- a/diffsynth/pipelines/cog_video.py
+++ b/diffsynth/pipelines/cog_video.py
@@ -13,7 +13,7 @@ from einops import rearrange
 class CogVideoPipeline(BasePipeline):
 
     def __init__(self, device="cuda", torch_dtype=torch.float16):
-        super().__init__(device=device, torch_dtype=torch_dtype)
+        super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
         self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction")
         self.prompter = CogPrompter()
         # models
diff --git a/examples/image_synthesis/sd35_text_to_image.py b/examples/image_synthesis/sd35_text_to_image.py
index 37f2da8..94f59e0 100644
--- a/examples/image_synthesis/sd35_text_to_image.py
+++ b/examples/image_synthesis/sd35_text_to_image.py
@@ -2,7 +2,7 @@ from diffsynth import ModelManager, SD3ImagePipeline
 import torch
 
 
-model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["StableDiffusion3.5-large"])
+model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", model_id_list=["StableDiffusion3.5-large"])
 pipe = SD3ImagePipeline.from_model_manager(model_manager)
 
 prompt = "a full body photo of a beautiful Asian girl. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
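
Not part of the patch above, but as a quick smoke test for the new preset: a minimal sketch that swaps the `model_id_list` entry over to `StableDiffusion3.5-medium`, mirroring `sd35_text_to_image.py`. The sampling arguments (`negative_prompt`, `cfg_scale`, `num_inference_steps`) are assumptions modeled on the other SD3 examples, not values taken from this change.

```python
from diffsynth import ModelManager, SD3ImagePipeline
import torch

# The new "StableDiffusion3.5-medium" preset downloads sd3.5_medium.safetensors
# plus the shared CLIP/T5 text encoders registered in model_config.py above.
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", model_id_list=["StableDiffusion3.5-medium"])
pipe = SD3ImagePipeline.from_model_manager(model_manager)

torch.manual_seed(0)
image = pipe(
    prompt="a cat sitting on a windowsill, masterpiece, best quality",
    negative_prompt="lowres, blurry, worst quality",
    cfg_scale=7.0,            # assumed guidance scale, not specified by the patch
    num_inference_steps=28,   # assumed step count; the large-turbo checkpoint would need far fewer
)
image.save("sd35_medium.jpg")
```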