From 02fcfd530f894d4f9ffd59374aff062ebfc27e58 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 15 Nov 2024 14:20:39 +0800 Subject: [PATCH 1/2] support sd3.5 medium and large-turbo --- diffsynth/configs/model_config.py | 15 +- diffsynth/models/sd3_dit.py | 140 ++++++++++++++++-- .../image_synthesis/sd35_text_to_image.py | 2 +- 3 files changed, 145 insertions(+), 12 deletions(-) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index b4ec524..b5057d6 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -90,7 +90,7 @@ model_loader_configs = [ (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"), (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"), - # (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai") + (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -590,6 +590,18 @@ preset_models_on_modelscope = { ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"), ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"), ], + "StableDiffusion3.5-medium": [ + ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"), + ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"), + ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"), + ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"), + ], + "StableDiffusion3.5-large-turbo": [ + ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"), + ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"), + ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"), + ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"), + ], } Preset_model_id: TypeAlias = Literal[ "HunyuanDiT", @@ -643,4 +655,5 @@ Preset_model_id: TypeAlias = Literal[ "Annotators:Normal", "Annotators:Openpose", "StableDiffusion3.5-large", + "StableDiffusion3.5-medium", ] diff --git a/diffsynth/models/sd3_dit.py b/diffsynth/models/sd3_dit.py index 5b44068..6168088 100644 --- a/diffsynth/models/sd3_dit.py +++ b/diffsynth/models/sd3_dit.py @@ -62,10 +62,11 @@ class TimestepEmbeddings(torch.nn.Module): class AdaLayerNorm(torch.nn.Module): - def __init__(self, dim, single=False): + def __init__(self, dim, single=False, dual=False): super().__init__() self.single = single - self.linear = torch.nn.Linear(dim, dim * (2 if single else 6)) + self.dual = dual + self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual]) self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) def forward(self, x, emb): @@ -74,6 +75,12 @@ class AdaLayerNorm(torch.nn.Module): scale, shift = emb.unsqueeze(1).chunk(2, dim=2) x = self.norm(x) * (1 + scale) + shift return x + elif self.dual: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2) + norm_x = self.norm(x) + x = norm_x * (1 + scale_msa) + shift_msa + norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2 + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2 else: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2) x = self.norm(x) * (1 + scale_msa) + shift_msa @@ -138,16 +145,58 @@ class JointAttention(torch.nn.Module): else: hidden_states_b = self.b_to_out(hidden_states_b) return hidden_states_a, hidden_states_b + +class SingleAttention(torch.nn.Module): + def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim -class JointTransformerBlock(torch.nn.Module): + self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3) + self.a_to_out = torch.nn.Linear(dim_a, dim_a) + + if use_rms_norm: + self.norm_q_a = RMSNorm(head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(head_dim, eps=1e-6) + else: + self.norm_q_a = None + self.norm_k_a = None + + + def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k): + batch_size = hidden_states.shape[0] + qkv = to_qkv(hidden_states) + qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q, k, v = qkv.chunk(3, dim=1) + if norm_q is not None: + q = norm_q(q) + if norm_k is not None: + k = norm_k(k) + return q, k, v + + + def forward(self, hidden_states_a): + batch_size = hidden_states_a.shape[0] + q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.a_to_out(hidden_states) + return hidden_states + + + +class DualTransformerBlock(torch.nn.Module): def __init__(self, dim, num_attention_heads, use_rms_norm=False): super().__init__() - self.norm1_a = AdaLayerNorm(dim) + self.norm1_a = AdaLayerNorm(dim, dual=True) self.norm1_b = AdaLayerNorm(dim) self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm) + self.attn2 = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm) self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff_a = torch.nn.Sequential( @@ -165,7 +214,7 @@ class JointTransformerBlock(torch.nn.Module): def forward(self, hidden_states_a, hidden_states_b, temb): - norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) + norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb) norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) # Attention @@ -173,6 +222,58 @@ class JointTransformerBlock(torch.nn.Module): # Part A hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a + hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2) + norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a + hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a) + + # Part B + hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b + norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b + hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b) + + return hidden_states_a, hidden_states_b + + + +class JointTransformerBlock(torch.nn.Module): + def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False): + super().__init__() + self.norm1_a = AdaLayerNorm(dim, dual=dual) + self.norm1_b = AdaLayerNorm(dim) + + self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm) + if dual: + self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm) + + self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_a = torch.nn.Sequential( + torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_b = torch.nn.Sequential( + torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + + def forward(self, hidden_states_a, hidden_states_b, temb): + if self.norm1_a.dual: + norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb) + else: + norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) + norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) + + # Attention + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b) + + # Part A + hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a + if self.norm1_a.dual: + hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2) norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a) @@ -218,13 +319,14 @@ class JointTransformerFinalBlock(torch.nn.Module): class SD3DiT(torch.nn.Module): - def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False): + def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192): super().__init__() - self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=192) + self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size) self.time_embedder = TimestepEmbeddings(256, embed_dim) self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, embed_dim), torch.nn.SiLU(), torch.nn.Linear(embed_dim, embed_dim)) self.context_embedder = torch.nn.Linear(4096, embed_dim) - self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1)] + self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm, dual=True) for _ in range(num_dual_blocks)] + + [JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1-num_dual_blocks)] + [JointTransformerFinalBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm)]) self.norm_out = AdaLayerNorm(embed_dim, single=True) self.proj_out = torch.nn.Linear(embed_dim, 64) @@ -286,7 +388,17 @@ class SD3DiTStateDictConverter: while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict: num_layers -= 1 use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict - return {"embed_dim": embed_dim, "num_layers": num_layers, "use_rms_norm": use_rms_norm} + num_dual_blocks = 0 + while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict: + num_dual_blocks += 1 + pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1] + return { + "embed_dim": embed_dim, + "num_layers": num_layers, + "use_rms_norm": use_rms_norm, + "num_dual_blocks": num_dual_blocks, + "pos_embed_max_size": pos_embed_max_size + } def from_diffusers(self, state_dict): rename_dict = { @@ -402,13 +514,21 @@ class SD3DiTStateDictConverter: f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight", f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight", f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight", + + f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_q.weight": f"blocks.{i}.attn2.norm_q_a.weight", + f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_k.weight": f"blocks.{i}.attn2.norm_k_a.weight", + f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.weight": f"blocks.{i}.attn2.a_to_qkv.weight", + f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.bias": f"blocks.{i}.attn2.a_to_qkv.bias", + f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.weight": f"blocks.{i}.attn2.a_to_out.weight", + f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.bias": f"blocks.{i}.attn2.a_to_out.bias", }) state_dict_ = {} for name in state_dict: if name in rename_dict: param = state_dict[name] if name == "model.diffusion_model.pos_embed": - param = param.reshape((1, 192, 192, param.shape[-1])) + pos_embed_max_size = int(param.shape[1] ** 0.5 + 0.4) + param = param.reshape((1, pos_embed_max_size, pos_embed_max_size, param.shape[-1])) if isinstance(rename_dict[name], str): state_dict_[rename_dict[name]] = param else: diff --git a/examples/image_synthesis/sd35_text_to_image.py b/examples/image_synthesis/sd35_text_to_image.py index 37f2da8..94f59e0 100644 --- a/examples/image_synthesis/sd35_text_to_image.py +++ b/examples/image_synthesis/sd35_text_to_image.py @@ -2,7 +2,7 @@ from diffsynth import ModelManager, SD3ImagePipeline import torch -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["StableDiffusion3.5-large"]) +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", model_id_list=["StableDiffusion3.5-large"]) pipe = SD3ImagePipeline.from_model_manager(model_manager) prompt = "a full body photo of a beautiful Asian girl. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." From 9cb4aa16eba019f5235352db1a9012e673f18d58 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 20 Nov 2024 09:51:31 +0800 Subject: [PATCH 2/2] fix cogvideo height width checker --- diffsynth/pipelines/cog_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/pipelines/cog_video.py b/diffsynth/pipelines/cog_video.py index 4b7f336..f42d295 100644 --- a/diffsynth/pipelines/cog_video.py +++ b/diffsynth/pipelines/cog_video.py @@ -13,7 +13,7 @@ from einops import rearrange class CogVideoPipeline(BasePipeline): def __init__(self, device="cuda", torch_dtype=torch.float16): - super().__init__(device=device, torch_dtype=torch_dtype) + super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16) self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction") self.prompter = CogPrompter() # models