mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
tmp commit
This commit is contained in:
@@ -69,6 +69,7 @@ from ..models.flux_value_control import SingleValueEncoder
|
|||||||
from ..lora.flux_lora import FluxLoraPatcher
|
from ..lora.flux_lora import FluxLoraPatcher
|
||||||
from ..models.flux_lora_encoder import FluxLoRAEncoder
|
from ..models.flux_lora_encoder import FluxLoRAEncoder
|
||||||
|
|
||||||
|
from ..models.nexus_gen_projector import NexusGenAdapter
|
||||||
|
|
||||||
model_loader_configs = [
|
model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -152,6 +153,7 @@ model_loader_configs = [
|
|||||||
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
|
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
|
||||||
(None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
|
(None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
|
||||||
(None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"),
|
(None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"),
|
||||||
|
(None, "3e6c61b0f9471135fc9c6d6a98e98b6d", ["flux_dit", "nexus-gen_adapter"], [FluxDiT, NexusGenAdapter], "civitai"),
|
||||||
]
|
]
|
||||||
huggingface_model_loader_configs = [
|
huggingface_model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import torch
|
|||||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .tiler import TileWorker
|
from .tiler import TileWorker
|
||||||
from .utils import init_weights_on_device
|
from .utils import init_weights_on_device, hash_state_dict_keys
|
||||||
|
|
||||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||||
batch_size, num_tokens = hidden_states.shape[0:2]
|
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||||
@@ -662,6 +662,9 @@ class FluxDiTStateDictConverter:
|
|||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict):
|
||||||
|
if hash_state_dict_keys(state_dict, with_shape=True) == "3e6c61b0f9471135fc9c6d6a98e98b6d":
|
||||||
|
dit_state_dict = {key.replace("pipe.dit.", ""): value for key, value in state_dict.items() if not key.startswith('adapter.')}
|
||||||
|
return dit_state_dict
|
||||||
rename_dict = {
|
rename_dict = {
|
||||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||||
|
|||||||
21
examples/flux/model_inference/Nexus-Gen-Generation.py
Normal file
21
examples/flux/model_inference/Nexus-Gen-Generation.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import importlib
|
||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
|
||||||
|
|
||||||
|
if importlib.util.find_spec("transformers") is None:
|
||||||
|
raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.")
|
||||||
|
else:
|
||||||
|
import transformers
|
||||||
|
assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==0.49.0, please install it with `pip install transformers==0.49.0`."
|
||||||
|
|
||||||
|
pipe = FluxImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2"),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||||
|
],
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user