diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index b60c200..9fa652d 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -69,6 +69,7 @@ from ..models.flux_value_control import SingleValueEncoder from ..lora.flux_lora import FluxLoraPatcher from ..models.flux_lora_encoder import FluxLoRAEncoder +from ..models.nexus_gen_projector import NexusGenAdapter model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -152,6 +153,7 @@ model_loader_configs = [ (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"), (None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"), + (None, "3e6c61b0f9471135fc9c6d6a98e98b6d", ["flux_dit", "nexus-gen_adapter"], [FluxDiT, NexusGenAdapter], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. 
diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index ea5ce21..3dd728d 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -2,7 +2,7 @@ import torch from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm from einops import rearrange from .tiler import TileWorker -from .utils import init_weights_on_device +from .utils import init_weights_on_device, hash_state_dict_keys def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0): batch_size, num_tokens = hidden_states.shape[0:2] @@ -662,6 +662,9 @@ class FluxDiTStateDictConverter: return state_dict_ def from_civitai(self, state_dict): + if hash_state_dict_keys(state_dict, with_shape=True) == "3e6c61b0f9471135fc9c6d6a98e98b6d": + dit_state_dict = {key.replace("pipe.dit.", ""): value for key, value in state_dict.items() if not key.startswith('adapter.')} + return dit_state_dict rename_dict = { "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight", diff --git a/examples/flux/model_inference/Nexus-Gen-Generation.py b/examples/flux/model_inference/Nexus-Gen-Generation.py new file mode 100644 index 0000000..102b7ef --- /dev/null +++ b/examples/flux/model_inference/Nexus-Gen-Generation.py @@ -0,0 +1,21 @@ +import importlib +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." 
+ +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +)