mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
Add: FLUX
This commit is contained in:
@@ -1,4 +1,41 @@
|
||||
import torch
|
||||
import hashlib
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
all_keys = sorted(list(state_dict))
|
||||
|
||||
for key in all_keys:
|
||||
value = state_dict[key]
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, torch.Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
|
||||
def FluxDiTStateDictConverter(state_dict):
|
||||
model_hash = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
|
||||
if model_hash in ["3e6c61b0f9471135fc9c6d6a98e98b6d", "63c969fd37cce769a90aa781fbff5f81"]:
|
||||
dit_state_dict = {}
|
||||
for key in state_dict:
|
||||
if key.startswith('pipe.dit.'):
|
||||
value = state_dict[key]
|
||||
new_key = key.replace("pipe.dit.", "")
|
||||
dit_state_dict[new_key] = value
|
||||
return dit_state_dict
|
||||
|
||||
rename_dict = {
|
||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
|
||||
34
diffsynth/utils/state_dict_converters/flux_ipadapter.py
Normal file
34
diffsynth/utils/state_dict_converters/flux_ipadapter.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
|
||||
def FluxIpAdapterStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
|
||||
if "ip_adapter" in state_dict and isinstance(state_dict["ip_adapter"], dict):
|
||||
for name, param in state_dict["ip_adapter"].items():
|
||||
name_ = 'ipadapter_modules.' + name
|
||||
state_dict_[name_] = param
|
||||
|
||||
if "image_proj" in state_dict:
|
||||
for name, param in state_dict["image_proj"].items():
|
||||
name_ = "image_proj." + name
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict_[key] = value
|
||||
elif key.startswith("ip_adapter."):
|
||||
new_key = key.replace("ip_adapter.", "ipadapter_modules.")
|
||||
state_dict_[new_key] = value
|
||||
else:
|
||||
pass
|
||||
|
||||
return state_dict_
|
||||
|
||||
|
||||
def SiglipStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
for key in state_dict:
|
||||
if key.startswith("vision_model."):
|
||||
new_state_dict[key] = state_dict[key]
|
||||
return new_state_dict
|
||||
8
diffsynth/utils/state_dict_converters/nexus_gen.py
Normal file
8
diffsynth/utils/state_dict_converters/nexus_gen.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import torch
|
||||
|
||||
def NexusGenAutoregressiveModelStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
for key in state_dict:
|
||||
value = state_dict[key]
|
||||
new_state_dict["model." + key] = value
|
||||
return new_state_dict
|
||||
17
diffsynth/utils/state_dict_converters/nexus_gen_projector.py
Normal file
17
diffsynth/utils/state_dict_converters/nexus_gen_projector.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
def NexusGenMergerStateDictConverter(state_dict):
|
||||
merger_state_dict = {}
|
||||
for key in state_dict:
|
||||
if key.startswith('embedding_merger.'):
|
||||
value = state_dict[key]
|
||||
new_key = key.replace("embedding_merger.", "")
|
||||
merger_state_dict[new_key] = value
|
||||
return merger_state_dict
|
||||
|
||||
def NexusGenAdapterStateDictConverter(state_dict):
|
||||
adapter_state_dict = {}
|
||||
for key in state_dict:
|
||||
if key.startswith('adapter.'):
|
||||
adapter_state_dict[key] = state_dict[key]
|
||||
return adapter_state_dict
|
||||
Reference in New Issue
Block a user