diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 172400b..24c2b7c 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -255,12 +255,36 @@ wan_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter", }, { - # ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors") + # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors") "model_hash": "06be60f3a4526586d8431cd038a71486", "model_name": "wans2v_audio_encoder", "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder", "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter", - } + }, ] -MODEL_CONFIGS = qwen_image_series + wan_series +flux_series = [ + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors") + "model_hash": "a29710fea6dddb0314663ee823598e50", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors") + "model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78", + "model_name": "flux_text_encoder_clip", + "model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter", + }, + { + # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors") + "model_hash": "22540b49eaedbc2f2784b2091a234c7c", + "model_name": "flux_text_encoder_t5", + "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter", + }, +] + +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series diff --git a/diffsynth/models/flux_controlnet.py b/diffsynth/models/flux_controlnet.py new file mode 100644 index 0000000..85fccd7 --- /dev/null +++ b/diffsynth/models/flux_controlnet.py @@ -0,0 +1,331 @@ +import torch +from einops import rearrange, repeat +from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm +from .utils import hash_state_dict_keys, init_weights_on_device + + + +class FluxControlNet(torch.nn.Module): + def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0): + super().__init__() + self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) + self.time_embedder = TimestepEmbeddings(256, 3072) + self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) + self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) + self.context_embedder = torch.nn.Linear(4096, 3072) + self.x_embedder = torch.nn.Linear(64, 3072) + + self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)]) + self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)]) + + self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)]) + self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)]) + + self.mode_dict = mode_dict + self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None + self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072) + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def patchify(self, hidden_states): + hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + + def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states): + if len(res_stack) == 0: + return [torch.zeros_like(hidden_states)] * num_blocks + interval = (num_blocks + len(res_stack) - 1) // len(res_stack) + aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)] + return aligned_res_stack + + + def forward( + self, + hidden_states, + controlnet_conditioning, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + processor_id=None, + tiled=False, tile_size=128, tile_stride=64, + **kwargs + ): + if image_ids is None: + image_ids = self.prepare_image_ids(hidden_states) + + conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) + if self.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) + prompt_emb = self.context_embedder(prompt_emb) + if self.controlnet_mode_embedder is not None: # Different from FluxDiT + processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int) + processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device) + prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1) + text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + + hidden_states = self.patchify(hidden_states) + hidden_states = self.x_embedder(hidden_states) + controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT + + controlnet_res_stack = [] + for block, controlnet_block in zip(self.blocks, self.controlnet_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + controlnet_res_stack.append(controlnet_block(hidden_states)) + + controlnet_single_res_stack = [] + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:])) + + controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:]) + controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:]) + + return controlnet_res_stack, controlnet_single_res_stack + + + @staticmethod + def state_dict_converter(): + return FluxControlNetStateDictConverter() + + def quantize(self): + def cast_to(weight, dtype=None, device=None, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) + + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight) + return r + + def cast_weight(s, input=None, dtype=None, device=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if device is None: + device = input.device + weight = cast_to(s.weight, dtype, device) + return weight + + def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): + if input is not None: + if dtype is None: + dtype = input.dtype + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + bias = None + weight = cast_to(s.weight, dtype, device) + bias = cast_to(s.bias, bias_dtype, device) + return weight, bias + + class quantized_layer: + class QLinear(torch.nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight,bias= cast_bias_weight(self,input) + return torch.nn.functional.linear(input,weight,bias) + + class QRMSNorm(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self,hidden_states,**kwargs): + weight= cast_weight(self.module,hidden_states) + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) + hidden_states = hidden_states.to(input_dtype) * weight + return hidden_states + + class QEmbedding(torch.nn.Embedding): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self,input,**kwargs): + weight= cast_weight(self,input) + return torch.nn.functional.embedding( + input, weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + + def replace_layer(model): + for name, module in model.named_children(): + if isinstance(module,quantized_layer.QRMSNorm): + continue + if isinstance(module, torch.nn.Linear): + with init_weights_on_device(): + new_layer = quantized_layer.QLinear(module.in_features,module.out_features) + new_layer.weight = module.weight + if module.bias is not None: + new_layer.bias = module.bias + setattr(model, name, new_layer) + elif isinstance(module, RMSNorm): + if hasattr(module,"quantized"): + continue + module.quantized= True + new_layer = quantized_layer.QRMSNorm(module) + setattr(model, name, new_layer) + elif isinstance(module,torch.nn.Embedding): + rows, cols = module.weight.shape + new_layer = quantized_layer.QEmbedding( + num_embeddings=rows, + embedding_dim=cols, + _weight=module.weight, + # _freeze=module.freeze, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse) + setattr(model, name, new_layer) + else: + replace_layer(module) + + replace_layer(self) + + + +class FluxControlNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + hash_value = hash_state_dict_keys(state_dict) + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + if hash_value == "78d18b9101345ff695f312e7e62538c0": + extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}} + elif hash_value == "b001c89139b5f053c715fe772362dd2a": + extra_kwargs = {"num_single_blocks": 0} + elif hash_value == "52357cb26250681367488a8954c271e8": + extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4} + elif hash_value == "0cfd1740758423a2a854d67c136d1e8c": + extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1} + elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16": + extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10} + elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52": + extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0} + else: + extra_kwargs = {} + return state_dict_, extra_kwargs + + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py new file mode 100644 index 0000000..51a6e7f --- /dev/null +++ b/diffsynth/models/flux_dit.py @@ -0,0 +1,395 @@ +import torch +from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm +from einops import rearrange + + +def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size, num_tokens = hidden_states.shape[0:2] + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + +class RoPEEmbedding(torch.nn.Module): + def __init__(self, dim, theta, axes_dim): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + + def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0, "The dimension must be even." + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + + + def forward(self, ids): + n_axes = ids.shape[-1] + emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) + return emb.unsqueeze(1) + + + +class FluxJointAttention(torch.nn.Module): + def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.only_out_a = only_out_a + + self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3) + self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3) + + self.norm_q_a = RMSNorm(head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(head_dim, eps=1e-6) + self.norm_q_b = RMSNorm(head_dim, eps=1e-6) + self.norm_k_b = RMSNorm(head_dim, eps=1e-6) + + self.a_to_out = torch.nn.Linear(dim_a, dim_a) + if not only_out_a: + self.b_to_out = torch.nn.Linear(dim_b, dim_b) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + batch_size = hidden_states_a.shape[0] + + # Part A + qkv_a = self.a_to_qkv(hidden_states_a) + qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_a, k_a, v_a = qkv_a.chunk(3, dim=1) + q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) + + # Part B + qkv_b = self.b_to_qkv(hidden_states_b) + qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_b, k_b, v_b = qkv_b.chunk(3, dim=1) + q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b) + + q = torch.concat([q_b, q_a], dim=2) + k = torch.concat([k_b, k_a], dim=2) + v = torch.concat([v_b, v_a], dim=2) + + q, k = self.apply_rope(q, k, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] + if ipadapter_kwargs_list is not None: + hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list) + hidden_states_a = self.a_to_out(hidden_states_a) + if self.only_out_a: + return hidden_states_a + else: + hidden_states_b = self.b_to_out(hidden_states_b) + return hidden_states_a, hidden_states_b + + + +class FluxJointTransformerBlock(torch.nn.Module): + def __init__(self, dim, num_attention_heads): + super().__init__() + self.norm1_a = AdaLayerNorm(dim) + self.norm1_b = AdaLayerNorm(dim) + + self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads) + + self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_a = torch.nn.Sequential( + torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_b = torch.nn.Sequential( + torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) + norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) + + # Attention + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + + # Part A + hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a + norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a + hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a) + + # Part B + hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b + norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b + hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b) + + return hidden_states_a, hidden_states_b + + + +class FluxSingleAttention(torch.nn.Module): + def __init__(self, dim_a, dim_b, num_heads, head_dim): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3) + + self.norm_q_a = RMSNorm(head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(head_dim, eps=1e-6) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + + def forward(self, hidden_states, image_rotary_emb): + batch_size = hidden_states.shape[0] + + qkv_a = self.a_to_qkv(hidden_states) + qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q_a, k_a, v = qkv_a.chunk(3, dim=1) + q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a) + + q, k = self.apply_rope(q_a, k_a, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + return hidden_states + + + +class AdaLayerNormSingle(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(dim, 3 * dim, bias=True) + self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + + +class FluxSingleTransformerBlock(torch.nn.Module): + def __init__(self, dim, num_attention_heads): + super().__init__() + self.num_heads = num_attention_heads + self.head_dim = dim // num_attention_heads + self.dim = dim + + self.norm = AdaLayerNormSingle(dim) + self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4)) + self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6) + self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6) + + self.proj_out = torch.nn.Linear(dim * 5, dim) + + + def apply_rope(self, xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + + def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + batch_size = hidden_states.shape[0] + + qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) + q, k, v = qkv.chunk(3, dim=1) + q, k = self.norm_q_a(q), self.norm_k_a(k) + + q, k = self.apply_rope(q, k, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + if ipadapter_kwargs_list is not None: + hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list) + return hidden_states + + + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): + residual = hidden_states_a + norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) + hidden_states_a = self.to_qkv_mlp(norm_hidden_states) + attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] + + attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list) + mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") + + hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a) + hidden_states_a = residual + hidden_states_a + + return hidden_states_a, hidden_states_b + + + +class AdaLayerNormContinuous(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(dim, dim * 2, bias=True) + self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False) + + def forward(self, x, conditioning): + emb = self.linear(self.silu(conditioning)) + shift, scale = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None] + shift[:, None] + return x + + + +class FluxDiT(torch.nn.Module): + def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19): + super().__init__() + self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) + self.time_embedder = TimestepEmbeddings(256, 3072) + self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) + self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) + self.context_embedder = torch.nn.Linear(4096, 3072) + self.x_embedder = torch.nn.Linear(input_dim, 3072) + + self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)]) + self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)]) + + self.final_norm_out = AdaLayerNormContinuous(3072) + self.final_proj_out = torch.nn.Linear(3072, 64) + + self.input_dim = input_dim + + + def patchify(self, hidden_states): + hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + + def unpatchify(self, hidden_states, height, width): + hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) + return hidden_states + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + total_seq_len = N * prompt_seq_len + image_seq_len + patched_masks = [self.patchify(entity_masks[i]) for i in range(N)] + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + image_start = N * prompt_seq_len + image_end = N * prompt_seq_len + image_seq_len + # prompt-image mask + for i in range(N): + prompt_start = i * prompt_seq_len + prompt_end = (i + 1) * prompt_seq_len + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt mask + for i in range(N): + for j in range(N): + if i != j: + prompt_start_i = i * prompt_seq_len + prompt_end_i = (i + 1) * prompt_seq_len + prompt_start_j = j * prompt_seq_len + prompt_end_j = (j + 1) * prompt_seq_len + attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + return attention_mask + + + def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim): + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + return prompt_emb, image_rotary_emb, attention_mask + + + def forward( + self, + hidden_states, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None, + use_gradient_checkpointing=False, + **kwargs + ): + # (Deprecated) The real forward is in `pipelines.flux_image`. + return None diff --git a/diffsynth/models/flux_infiniteyou.py b/diffsynth/models/flux_infiniteyou.py new file mode 100644 index 0000000..861538a --- /dev/null +++ b/diffsynth/models/flux_infiniteyou.py @@ -0,0 +1,129 @@ +import math +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class InfiniteYouImageProjector(nn.Module): + + def __init__( + self, + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=8, + embedding_dim=512, + output_dim=4096, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ])) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + latents = latents.to(dtype=x.dtype, device=x.device) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + @staticmethod + def state_dict_converter(): + return FluxInfiniteYouImageProjectorStateDictConverter() + + +class FluxInfiniteYouImageProjectorStateDictConverter: + + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict['image_proj'] diff --git a/diffsynth/models/flux_ipadapter.py b/diffsynth/models/flux_ipadapter.py new file mode 100644 index 0000000..575c752 --- /dev/null +++ b/diffsynth/models/flux_ipadapter.py @@ -0,0 +1,94 @@ +from .svd_image_encoder import SVDImageEncoder +from .sd3_dit import RMSNorm +from transformers import CLIPImageProcessor +import torch + + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + +class IpAdapterModule(torch.nn.Module): + def __init__(self, num_attention_heads, attention_head_dim, input_dim): + super().__init__() + self.num_heads = num_attention_heads + self.head_dim = attention_head_dim + output_dim = num_attention_heads * attention_head_dim + self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False) + + + def forward(self, hidden_states): + batch_size = hidden_states.shape[0] + # ip_k + ip_k = self.to_k_ip(hidden_states) + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_k = self.norm_added_k(ip_k) + # ip_v + ip_v = self.to_v_ip(hidden_states) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + return ip_k, ip_v + + +class FluxIpAdapter(torch.nn.Module): + def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57): + super().__init__() + self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)]) + self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens) + self.set_adapter() + + def set_adapter(self): + self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))} + + def forward(self, hidden_states, scale=1.0): + hidden_states = self.image_proj(hidden_states) + hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) + ip_kv_dict = {} + for block_id in self.call_block_id: + ipadapter_id = self.call_block_id[block_id] + ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) + ip_kv_dict[block_id] = { + "ip_k": ip_k, + "ip_v": ip_v, + "scale": scale + } + return ip_kv_dict + + @staticmethod + def state_dict_converter(): + return FluxIpAdapterStateDictConverter() + + +class FluxIpAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {} + for name in state_dict["ip_adapter"]: + name_ = 'ipadapter_modules.' + name + state_dict_[name_] = state_dict["ip_adapter"][name] + for name in state_dict["image_proj"]: + name_ = "image_proj." + name + state_dict_[name_] = state_dict["image_proj"][name] + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/flux_lora_encoder.py b/diffsynth/models/flux_lora_encoder.py new file mode 100644 index 0000000..695640a --- /dev/null +++ b/diffsynth/models/flux_lora_encoder.py @@ -0,0 +1,111 @@ +import torch +from .sd_text_encoder import CLIPEncoderLayer + + +class LoRALayerBlock(torch.nn.Module): + def __init__(self, L, dim_in, dim_out): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(1, L, dim_in)) + self.layer_norm = torch.nn.LayerNorm(dim_out) + + def forward(self, lora_A, lora_B): + x = self.x @ lora_A.T @ lora_B.T + x = self.layer_norm(x) + return x + + +class LoRAEmbedder(torch.nn.Module): + def __init__(self, lora_patterns=None, L=1, out_dim=2048): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1]) + self.model_dict = torch.nn.ModuleDict(model_dict) + + proj_dict = {} + for lora_pattern in lora_patterns: + layer_type, dim = lora_pattern["type"], lora_pattern["dim"] + if layer_type not in proj_dict: + proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim) + self.proj_dict = torch.nn.ModuleDict(proj_dict) + + self.lora_patterns = lora_patterns + + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432), + "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432), + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix], + "type": suffix, + }) + return lora_patterns + + def forward(self, lora): + lora_emb = [] + for lora_pattern in self.lora_patterns: + name, layer_type = lora_pattern["name"], lora_pattern["type"] + lora_A = lora[name + ".lora_A.default.weight"] + lora_B = lora[name + ".lora_B.default.weight"] + lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) + lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) + lora_emb.append(lora_out) + lora_emb = torch.concat(lora_emb, dim=1) + return lora_emb + + +class FluxLoRAEncoder(torch.nn.Module): + def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1): + super().__init__() + self.num_embeds_per_lora = num_embeds_per_lora + # embedder + self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)]) + + # special embedding + self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim)) + self.num_special_embeds = num_special_embeds + + # final layer + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + self.final_linear = torch.nn.Linear(embed_dim, embed_dim) + + def forward(self, lora): + lora_embeds = self.embedder(lora) + special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device) + embeds = torch.concat([special_embeds, lora_embeds], dim=1) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds) + embeds = embeds[:, :self.num_special_embeds] + embeds = self.final_layer_norm(embeds) + embeds = self.final_linear(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return FluxLoRAEncoderStateDictConverter() + + +class FluxLoRAEncoderStateDictConverter: + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/models/flux_text_encoder_clip.py b/diffsynth/models/flux_text_encoder_clip.py new file mode 100644 index 0000000..1425423 --- /dev/null +++ b/diffsynth/models/flux_text_encoder_clip.py @@ -0,0 +1,112 @@ +import torch + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class FluxTextEncoderClip(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=2, extra_mask=None): + embeds = self.token_embedding(input_ids) + embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device) + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + if extra_mask is not None: + attn_mask[:, extra_mask[0]==0] = float("-inf") + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + hidden_states = embeds + embeds = self.final_layer_norm(embeds) + pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)] + return pooled_embeds, hidden_states diff --git a/diffsynth/models/flux_text_encoder_t5.py b/diffsynth/models/flux_text_encoder_t5.py new file mode 100644 index 0000000..ee72e4a --- /dev/null +++ b/diffsynth/models/flux_text_encoder_t5.py @@ -0,0 +1,43 @@ +import torch +from transformers import T5EncoderModel, T5Config + + +class FluxTextEncoderT5(T5EncoderModel): + def __init__(self): + config = T5Config(**{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "dtype": "bfloat16", + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": True, + "is_gated_act": True, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": True, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": False, + "transformers_version": "4.57.1", + "use_cache": True, + "vocab_size": 32128 + }) + super().__init__(config) + + def forward(self, input_ids): + outputs = super().forward(input_ids=input_ids) + prompt_emb = outputs.last_hidden_state + return prompt_emb diff --git a/diffsynth/models/flux_vae.py b/diffsynth/models/flux_vae.py new file mode 100644 index 0000000..cbb7038 --- /dev/null +++ b/diffsynth/models/flux_vae.py @@ -0,0 +1,394 @@ +import torch +from einops import rearrange, repeat + + +class TileWorker: + def __init__(self): + pass + + + def mask(self, height, width, border_width): + # Create a mask with shape (height, width). + # The centre area is filled with 1, and the border line is filled with values in range (0, 1]. + x = torch.arange(height).repeat(width, 1).T + y = torch.arange(width).repeat(height, 1) + mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values + mask = (mask / border_width).clip(0, 1) + return mask + + + def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype): + # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num) + batch_size, channel, _, _ = model_input.shape + model_input = model_input.to(device=tile_device, dtype=tile_dtype) + unfold_operator = torch.nn.Unfold( + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + model_input = unfold_operator(model_input) + model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1)) + + return model_input + + + def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype): + # Call y=forward_fn(x) for each tile + tile_num = model_input.shape[-1] + model_output_stack = [] + + for tile_id in range(0, tile_num, tile_batch_size): + + # process input + tile_id_ = min(tile_id + tile_batch_size, tile_num) + x = model_input[:, :, :, :, tile_id: tile_id_] + x = x.to(device=inference_device, dtype=inference_dtype) + x = rearrange(x, "b c h w n -> (n b) c h w") + + # process output + y = forward_fn(x) + y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id) + y = y.to(device=tile_device, dtype=tile_dtype) + model_output_stack.append(y) + + model_output = torch.concat(model_output_stack, dim=-1) + return model_output + + + def io_scale(self, model_output, tile_size): + # Determine the size modification happened in forward_fn + # We only consider the same scale on height and width. + io_scale = model_output.shape[2] / tile_size + return io_scale + + + def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype): + # The reversed function of tile + mask = self.mask(tile_size, tile_size, border_width) + mask = mask.to(device=tile_device, dtype=tile_dtype) + mask = rearrange(mask, "h w -> 1 1 h w 1") + model_output = model_output * mask + + fold_operator = torch.nn.Fold( + output_size=(height, width), + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1]) + model_output = rearrange(model_output, "b c h w n -> b (c h w) n") + model_output = fold_operator(model_output) / fold_operator(mask) + + return model_output + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + inference_device, inference_dtype = model_input.device, model_input.dtype + height, width = model_input.shape[2], model_input.shape[3] + border_width = int(tile_stride*0.5) if border_width is None else border_width + + # tile + model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype) + + # inference + model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype) + + # resize + io_scale = self.io_scale(model_output, tile_size) + height, width = int(height*io_scale), int(width*io_scale) + tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale) + border_width = int(border_width*io_scale) + + # untile + model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype) + + # Done! + model_output = model_output.to(device=inference_device, dtype=inference_dtype) + return model_output + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + +class VAEAttentionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + + self.transformer_blocks = torch.nn.ModuleList([ + Attention( + inner_dim, + num_attention_heads, + attention_head_dim, + bias_q=True, + bias_kv=True, + bias_out=True + ) + for d in range(num_layers) + ]) + + def forward(self, hidden_states, time_emb, text_emb, res_stack): + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states) + + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + + return hidden_states, time_emb, text_emb, res_stack + + +class ResnetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5): + super().__init__() + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nonlinearity = torch.nn.SiLU() + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + x = hidden_states + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + if time_emb is not None: + emb = self.nonlinearity(time_emb) + emb = self.time_emb_proj(emb)[:, :, None, None] + x = x + emb + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.conv2(x) + if self.conv_shortcut is not None: + hidden_states = self.conv_shortcut(hidden_states) + hidden_states = hidden_states + x + return hidden_states, time_emb, text_emb, res_stack + + +class UpSampler(torch.nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class DownSampler(torch.nn.Module): + def __init__(self, channels, padding=1, extra_padding=False): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding) + self.extra_padding = extra_padding + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + if self.extra_padding: + hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class SD3VAEDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.scaling_factor = 0.3611 + self.shift_factor = 0.1159 + self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x + + self.blocks = torch.nn.ModuleList([ + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + UpSampler(256), + # UpDecoderBlock2D + ResnetBlock(256, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states + + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Decoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + hidden_states = sample / self.scaling_factor + self.shift_factor + hidden_states = self.conv_in(hidden_states) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class SD3VAEEncoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.scaling_factor = 0.3611 + self.shift_factor = 0.1159 + self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # DownEncoderBlock2D + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + DownSampler(128, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(128, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + DownSampler(256, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(256, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + DownSampler(512, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states + + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Decoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + hidden_states = self.conv_in(sample) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + hidden_states = hidden_states[:, :16] + hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor + + return hidden_states + + def encode_video(self, sample, batch_size=8): + B = sample.shape[0] + hidden_states = [] + + for i in range(0, sample.shape[2], batch_size): + + j = min(i + batch_size, sample.shape[2]) + sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W") + + hidden_states_batch = self(sample_batch) + hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B) + + hidden_states.append(hidden_states_batch) + + hidden_states = torch.concat(hidden_states, dim=2) + return hidden_states diff --git a/diffsynth/models/flux_value_control.py b/diffsynth/models/flux_value_control.py new file mode 100644 index 0000000..6981344 --- /dev/null +++ b/diffsynth/models/flux_value_control.py @@ -0,0 +1,60 @@ +import torch +from diffsynth.models.svd_unet import TemporalTimesteps + + +class MultiValueEncoder(torch.nn.Module): + def __init__(self, encoders=()): + super().__init__() + self.encoders = torch.nn.ModuleList(encoders) + + def __call__(self, values, dtype): + emb = [] + for encoder, value in zip(self.encoders, values): + if value is not None: + value = value.unsqueeze(0) + emb.append(encoder(value, dtype)) + emb = torch.concat(emb, dim=0) + return emb + + +class SingleValueEncoder(torch.nn.Module): + def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None): + super().__init__() + self.prefer_len = prefer_len + self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device) + self.prefer_value_embedder = torch.nn.Sequential( + torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out) + ) + self.positional_embedding = torch.nn.Parameter( + torch.randn(self.prefer_len, dim_out) + ) + self._initialize_weights() + + def _initialize_weights(self): + last_linear = self.prefer_value_embedder[-1] + torch.nn.init.zeros_(last_linear.weight) + torch.nn.init.zeros_(last_linear.bias) + + def forward(self, value, dtype): + value = value * 1000 + emb = self.prefer_proj(value).to(dtype) + emb = self.prefer_value_embedder(emb).squeeze(0) + base_embeddings = emb.expand(self.prefer_len, -1) + positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device) + learned_embeddings = base_embeddings + positional_embedding + return learned_embeddings + + @staticmethod + def state_dict_converter(): + return SingleValueEncoderStateDictConverter() + + +class SingleValueEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/models/model_loader.py b/diffsynth/models/model_loader.py index 7cfc134..b79e277 100644 --- a/diffsynth/models/model_loader.py +++ b/diffsynth/models/model_loader.py @@ -77,7 +77,7 @@ class ModelPool: print(f"Loaded model: {json.dumps(model_info, indent=4)}") loaded = True if not loaded: - raise ValueError(f"Cannot detect the model type. File: {path}") + raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}.") def fetch_model(self, model_name, index=None): fetched_models = [] diff --git a/diffsynth/utils/state_dict_converters/flux_dit.py b/diffsynth/utils/state_dict_converters/flux_dit.py new file mode 100644 index 0000000..3250a22 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_dit.py @@ -0,0 +1,77 @@ +def FluxDiTStateDictConverter(state_dict): + rename_dict = { + "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", + "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight", + "time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias", + "time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight", + "txt_in.bias": "context_embedder.bias", + "txt_in.weight": "context_embedder.weight", + "vector_in.in_layer.bias": "pooled_text_embedder.0.bias", + "vector_in.in_layer.weight": "pooled_text_embedder.0.weight", + "vector_in.out_layer.bias": "pooled_text_embedder.2.bias", + "vector_in.out_layer.weight": "pooled_text_embedder.2.weight", + "final_layer.linear.bias": "final_proj_out.bias", + "final_layer.linear.weight": "final_proj_out.weight", + "guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias", + "guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight", + "guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias", + "guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight", + "img_in.bias": "x_embedder.bias", + "img_in.weight": "x_embedder.weight", + "final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight", + "final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias", + } + suffix_rename_dict = { + "img_attn.norm.key_norm.scale": "attn.norm_k_a.weight", + "img_attn.norm.query_norm.scale": "attn.norm_q_a.weight", + "img_attn.proj.bias": "attn.a_to_out.bias", + "img_attn.proj.weight": "attn.a_to_out.weight", + "img_attn.qkv.bias": "attn.a_to_qkv.bias", + "img_attn.qkv.weight": "attn.a_to_qkv.weight", + "img_mlp.0.bias": "ff_a.0.bias", + "img_mlp.0.weight": "ff_a.0.weight", + "img_mlp.2.bias": "ff_a.2.bias", + "img_mlp.2.weight": "ff_a.2.weight", + "img_mod.lin.bias": "norm1_a.linear.bias", + "img_mod.lin.weight": "norm1_a.linear.weight", + "txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight", + "txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight", + "txt_attn.proj.bias": "attn.b_to_out.bias", + "txt_attn.proj.weight": "attn.b_to_out.weight", + "txt_attn.qkv.bias": "attn.b_to_qkv.bias", + "txt_attn.qkv.weight": "attn.b_to_qkv.weight", + "txt_mlp.0.bias": "ff_b.0.bias", + "txt_mlp.0.weight": "ff_b.0.weight", + "txt_mlp.2.bias": "ff_b.2.bias", + "txt_mlp.2.weight": "ff_b.2.weight", + "txt_mod.lin.bias": "norm1_b.linear.bias", + "txt_mod.lin.weight": "norm1_b.linear.weight", + + "linear1.bias": "to_qkv_mlp.bias", + "linear1.weight": "to_qkv_mlp.weight", + "linear2.bias": "proj_out.bias", + "linear2.weight": "proj_out.weight", + "modulation.lin.bias": "norm.linear.bias", + "modulation.lin.weight": "norm.linear.weight", + "norm.key_norm.scale": "norm_k_a.weight", + "norm.query_norm.scale": "norm_q_a.weight", + } + state_dict_ = {} + for name in state_dict: + original_name = name + if name.startswith("model.diffusion_model."): + name = name[len("model.diffusion_model."):] + names = name.split(".") + if name in rename_dict: + rename = rename_dict[name] + state_dict_[rename] = state_dict[original_name] + elif names[0] == "double_blocks": + rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] + state_dict_[rename] = state_dict[original_name] + elif names[0] == "single_blocks": + if ".".join(names[2:]) in suffix_rename_dict: + rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] + state_dict_[rename] = state_dict[original_name] + else: + pass + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py b/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py new file mode 100644 index 0000000..aa018aa --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_text_encoder_clip.py @@ -0,0 +1,31 @@ +def FluxTextEncoderClipStateDictConverter(state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias", + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py b/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py new file mode 100644 index 0000000..d35eb83 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_text_encoder_t5.py @@ -0,0 +1,4 @@ +def FluxTextEncoderT5StateDictConverter(state_dict): + state_dict_ = {i: state_dict[i] for i in state_dict} + state_dict_["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + return state_dict_ diff --git a/examples/test/run.py b/examples/test/run.py index 0f84bb4..f140d6e 100644 --- a/examples/test/run.py +++ b/examples/test/run.py @@ -71,13 +71,13 @@ if __name__ == "__main__": # run_inference("examples/qwen_image/model_inference_low_vram") # run_inference("examples/qwen_image/model_training/validate_full") # run_inference("examples/qwen_image/model_training/validate_lora") - run_train_single_GPU("examples/wanvideo/model_inference_low_vram") - move_files("video_", "data/output/model_inference_low_vram") - run_train_single_GPU("examples/wanvideo/model_inference") - move_files("video_", "data/output/model_inference") - run_train_single_GPU("examples/wanvideo/model_training/lora") - run_train_single_GPU("examples/wanvideo/model_training/validate_lora") - move_files("video_", "data/output/validate_lora") - run_train_multi_GPU("examples/wanvideo/model_training/full") - run_train_single_GPU("examples/wanvideo/model_training/validate_full") + # run_train_single_GPU("examples/wanvideo/model_inference_low_vram") + # move_files("video_", "data/output/model_inference_low_vram") + # run_train_single_GPU("examples/wanvideo/model_inference") + # move_files("video_", "data/output/model_inference") + # run_train_single_GPU("examples/wanvideo/model_training/lora") + # run_train_single_GPU("examples/wanvideo/model_training/validate_lora") + # move_files("video_", "data/output/validate_lora") + # run_train_multi_GPU("examples/wanvideo/model_training/full") + run_train_multi_GPU("examples/wanvideo/model_training/validate_full") move_files("video_", "data/output/validate_full") diff --git a/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh index 4352a05..fe85ca8 100644 --- a/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh +++ b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh @@ -9,7 +9,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_high_niose_full" \ + --output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_full" \ --trainable_models "dit" \ --extra_inputs "input_image,camera_control_direction,camera_control_speed" \ --max_timestep_boundary 0.358 \ diff --git a/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh index 2cc7da0..6f5ac87 100644 --- a/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh +++ b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh @@ -9,7 +9,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/Wan2.2-Fun-A14B-Control_high_niose_full" \ + --output_path "./models/train/Wan2.2-Fun-A14B-Control_high_noise_full" \ --trainable_models "dit" \ --extra_inputs "control_video,reference_image" \ --max_timestep_boundary 0.358 \ diff --git a/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh index 225b888..7c623a0 100644 --- a/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh +++ b/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh @@ -8,7 +8,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/Wan2.2-Fun-A14B-InP_high_niose_full" \ + --output_path "./models/train/Wan2.2-Fun-A14B-InP_high_noise_full" \ --trainable_models "dit" \ --extra_inputs "input_image,end_image" \ --max_timestep_boundary 0.358 \ diff --git a/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh index b97e800..1a9983b 100644 --- a/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh +++ b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh @@ -9,7 +9,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_high_niose_lora" \ + --output_path "./models/train/Wan2.2-Fun-A14B-Control-Camera_high_noise_lora" \ --lora_base_model "dit" \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ diff --git a/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh index f0af530..571ae54 100644 --- a/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh +++ b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh @@ -9,7 +9,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/Wan2.2-Fun-A14B-Control_high_niose_lora" \ + --output_path "./models/train/Wan2.2-Fun-A14B-Control_high_noise_lora" \ --lora_base_model "dit" \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ diff --git a/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh index 94cf196..491351c 100644 --- a/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh +++ b/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh @@ -8,7 +8,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/Wan2.2-Fun-A14B-InP_high_niose_lora" \ + --output_path "./models/train/Wan2.2-Fun-A14B-InP_high_noise_lora" \ --lora_base_model "dit" \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py index 9e914ff..0441a54 100644 --- a/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py @@ -1,7 +1,8 @@ import torch from PIL import Image import librosa -from diffsynth import VideoData, save_video_with_audio, load_state_dict +from diffsynth.utils.data import VideoData, save_video_with_audio +from diffsynth.core import load_state_dict from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py index 353e19e..ea28432 100644 --- a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py @@ -16,7 +16,7 @@ pipe = WanVideoPipeline.from_pretrained( ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), ], ) -pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_lora/epoch-4.safetensors", alpha=1) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_lora/epoch-4.safetensors", alpha=1) video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py index d936ca3..0a82de6 100644 --- a/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py @@ -1,7 +1,7 @@ import torch from PIL import Image import librosa -from diffsynth import VideoData, save_video_with_audio +from diffsynth.utils.data import VideoData, save_video_with_audio from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig pipe = WanVideoPipeline.from_pretrained(