diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 9d49b44..b4ec524 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -33,7 +33,7 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan from ..models.hunyuan_dit import HunyuanDiT from ..models.flux_dit import FluxDiT -from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2 +from ..models.flux_text_encoder import FluxTextEncoder2 from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder from ..models.flux_controlnet import FluxControlNet @@ -75,7 +75,7 @@ model_loader_configs = [ (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"), (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"), (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"), - (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"), + (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"), (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"), (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"), (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"), @@ -89,7 +89,6 @@ model_loader_configs = [ (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"), - (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "civitai"), (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"), # (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai") ] @@ -551,14 +550,20 @@ preset_models_on_modelscope = { ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"), ], # Omnigen - "OmniGen-v1": [ - ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1"), - ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"), - ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"), - ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"), - ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"), - ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"), - ], + "OmniGen-v1": { + "file_list": [ + ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"), + ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"), + ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"), + ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"), + ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"), + ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"), + ], + "load_path": [ + "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors", + "models/OmniGen/OmniGen-v1/model.safetensors", + ] + }, # CogVideo "CogVideoX-5B": { "file_list": [ diff --git a/diffsynth/models/flux_text_encoder.py b/diffsynth/models/flux_text_encoder.py index 9e04519..bff9d29 100644 --- a/diffsynth/models/flux_text_encoder.py +++ b/diffsynth/models/flux_text_encoder.py @@ -3,26 +3,6 @@ from transformers import T5EncoderModel, T5Config from .sd_text_encoder import SDTextEncoder -class FluxTextEncoder1(SDTextEncoder): - def __init__(self, vocab_size=49408): - super().__init__(vocab_size=vocab_size) - - def forward(self, input_ids, clip_skip=2): - embeds = self.token_embedding(input_ids) + self.position_embeds - attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) - for encoder_id, encoder in enumerate(self.encoders): - embeds = encoder(embeds, attn_mask=attn_mask) - if encoder_id + clip_skip == len(self.encoders): - hidden_states = embeds - embeds = self.final_layer_norm(embeds) - pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)] - return embeds, pooled_embeds - - @staticmethod - def state_dict_converter(): - return FluxTextEncoder1StateDictConverter() - - class FluxTextEncoder2(T5EncoderModel): def __init__(self, config): @@ -40,47 +20,6 @@ class FluxTextEncoder2(T5EncoderModel): -class FluxTextEncoder1StateDictConverter: - def __init__(self): - pass - - def from_diffusers(self, state_dict): - rename_dict = { - "text_model.embeddings.token_embedding.weight": "token_embedding.weight", - "text_model.embeddings.position_embedding.weight": "position_embeds", - "text_model.final_layer_norm.weight": "final_layer_norm.weight", - "text_model.final_layer_norm.bias": "final_layer_norm.bias" - } - attn_rename_dict = { - "self_attn.q_proj": "attn.to_q", - "self_attn.k_proj": "attn.to_k", - "self_attn.v_proj": "attn.to_v", - "self_attn.out_proj": "attn.to_out", - "layer_norm1": "layer_norm1", - "layer_norm2": "layer_norm2", - "mlp.fc1": "fc1", - "mlp.fc2": "fc2", - } - state_dict_ = {} - for name in state_dict: - if name in rename_dict: - param = state_dict[name] - if name == "text_model.embeddings.position_embedding.weight": - param = param.reshape((1, param.shape[0], param.shape[1])) - state_dict_[rename_dict[name]] = param - elif name.startswith("text_model.encoder.layers."): - param = state_dict[name] - names = name.split(".") - layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] - name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) - state_dict_[name_] = param - return state_dict_ - - def from_civitai(self, state_dict): - return self.from_diffusers(state_dict) - - - class FluxTextEncoder2StateDictConverter(): def __init__(self): pass diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index c68bdde..50b4a92 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -37,7 +37,7 @@ from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5Tex from .hunyuan_dit import HunyuanDiT from .flux_dit import FluxDiT -from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2 +from .flux_text_encoder import FluxTextEncoder2 from .flux_vae import FluxVAEEncoder, FluxVAEDecoder from .cog_vae import CogVAEEncoder, CogVAEDecoder diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index d76eadb..fb8813f 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -36,7 +36,7 @@ class BasePipeline(torch.nn.Module): return video - def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0): + def merge_latents(self, value, latents, masks, scales, blur_kernel_size=3, blur_sigma=1.0): blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma) height, width = value.shape[-2:] weight = torch.ones_like(value) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 69664b3..95ccadc 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -1,4 +1,4 @@ -from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder +from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..prompters import FluxPrompter from ..schedulers import FlowMatchScheduler @@ -19,7 +19,7 @@ class FluxImagePipeline(BasePipeline): self.scheduler = FlowMatchScheduler() self.prompter = FluxPrompter() # models - self.text_encoder_1: FluxTextEncoder1 = None + self.text_encoder_1: SD3TextEncoder1 = None self.text_encoder_2: FluxTextEncoder2 = None self.dit: FluxDiT = None self.vae_decoder: FluxVAEDecoder = None @@ -33,7 +33,7 @@ class FluxImagePipeline(BasePipeline): def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]): - self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1") + self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1") self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2") self.dit = model_manager.fetch_model("flux_dit") self.vae_decoder = model_manager.fetch_model("flux_vae_decoder") diff --git a/diffsynth/prompters/flux_prompter.py b/diffsynth/prompters/flux_prompter.py index 9a6bd7d..a3a06ff 100644 --- a/diffsynth/prompters/flux_prompter.py +++ b/diffsynth/prompters/flux_prompter.py @@ -1,5 +1,6 @@ from .base_prompter import BasePrompter -from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2 +from ..models.flux_text_encoder import FluxTextEncoder2 +from ..models.sd3_text_encoder import SD3TextEncoder1 from transformers import CLIPTokenizer, T5TokenizerFast import os, torch @@ -19,11 +20,11 @@ class FluxPrompter(BasePrompter): super().__init__() self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path) self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path) - self.text_encoder_1: FluxTextEncoder1 = None + self.text_encoder_1: SD3TextEncoder1 = None self.text_encoder_2: FluxTextEncoder2 = None - def fetch_models(self, text_encoder_1: FluxTextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None): + def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None): self.text_encoder_1 = text_encoder_1 self.text_encoder_2 = text_encoder_2 @@ -36,7 +37,7 @@ class FluxPrompter(BasePrompter): max_length=max_length, truncation=True ).input_ids.to(device) - _, pooled_prompt_emb = text_encoder(input_ids) + pooled_prompt_emb, _ = text_encoder(input_ids) return pooled_prompt_emb