This commit is contained in:
Artiprocher
2025-11-19 20:22:21 +08:00
parent 6ad8d73717
commit eeb55a0ce6
88 changed files with 3113 additions and 78 deletions

View File

@@ -1,6 +1,5 @@
import torch
import hashlib
import json
def FluxControlNetStateDictConverter(state_dict):
global_rename_dict = {

View File

@@ -1,39 +1,17 @@
import torch
import hashlib
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
all_keys = sorted(list(state_dict))
for key in all_keys:
value = state_dict[key]
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def FluxDiTStateDictConverter(state_dict):
model_hash = hash_state_dict_keys(state_dict, with_shape=True)
if model_hash in ["3e6c61b0f9471135fc9c6d6a98e98b6d", "63c969fd37cce769a90aa781fbff5f81"]:
is_nexus_gen = sum([key.startswith("pipe.dit.") for key in state_dict]) > 0
if is_nexus_gen:
dit_state_dict = {}
for key in state_dict:
if key.startswith('pipe.dit.'):
value = state_dict[key]
param = state_dict[key]
new_key = key.replace("pipe.dit.", "")
dit_state_dict[new_key] = value
if new_key.startswith("final_norm_out.linear."):
param = torch.concat([param[3072:], param[:3072]], dim=0)
dit_state_dict[new_key] = param
return dit_state_dict
rename_dict = {

View File

@@ -1,4 +1,2 @@
import torch
def FluxInfiniteYouImageProjectorStateDictConverter(state_dict):
return state_dict['image_proj']

View File

@@ -1,5 +1,3 @@
import torch
def FluxIpAdapterStateDictConverter(state_dict):
state_dict_ = {}

View File

@@ -1,5 +1,3 @@
import torch
def NexusGenAutoregressiveModelStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:

View File

@@ -1,5 +1,3 @@
import torch
def NexusGenMergerStateDictConverter(state_dict):
merger_state_dict = {}
for key in state_dict:

View File

@@ -0,0 +1,7 @@
def Qwen2ConnectorStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("connector."):
name_ = name[len("connector."):]
state_dict_[name_] = state_dict[name]
return state_dict_