mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
Made much much faster than before
enable debug to see every message
This commit is contained in:
@@ -1,4 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import time
|
||||||
|
from tqdm import tqdm
|
||||||
|
import psutil
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import multiprocessing
|
||||||
from .sd_unet import SDUNet
|
from .sd_unet import SDUNet
|
||||||
from .sdxl_unet import SDXLUNet
|
from .sdxl_unet import SDXLUNet
|
||||||
from .sd_text_encoder import SDTextEncoder
|
from .sd_text_encoder import SDTextEncoder
|
||||||
@@ -10,7 +17,67 @@ from .cog_dit import CogDiT
|
|||||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
from .hunyuan_video_dit import HunyuanVideoDiT
|
||||||
from .wan_video_dit import WanModel
|
from .wan_video_dit import WanModel
|
||||||
|
|
||||||
|
# Global debug variable: when set to False, only minimal info is printed.
|
||||||
|
DEBUG = False
|
||||||
|
|
||||||
|
def debug_print(*args, **kwargs):
    """Forward all arguments to ``print`` when the module-level ``DEBUG`` flag is truthy."""
    if not DEBUG:
        return
    print(*args, **kwargs)
|
||||||
|
|
||||||
|
def timing_decorator(func):
    """Wrap *func* so that its duration is printed when ``DEBUG`` is truthy.

    The wrapper forwards all positional and keyword arguments and returns the
    wrapped function's result unchanged. Nothing is printed if *func* raises.

    Args:
        func: the callable to time.

    Returns:
        A wrapper that carries the metadata of *func* (``__name__``,
        ``__doc__``, ``__wrapped__``).
    """
    # Function-scope import: the file's import header does not provide functools.
    from functools import wraps

    @wraps(func)  # keep the wrapped function's name/docstring for introspection
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the interval is not skewed by
        # system clock adjustments the way time.time() can be.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        if DEBUG:
            print(f"⏱️ {func.__name__} took {elapsed_time:.4f} seconds")
        return result

    return wrapper
|
||||||
|
|
||||||
|
def memory_usage():
    """Return the current process's RSS as a string in megabytes, e.g. ``"123.4 MB"``."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_megabytes = rss_bytes / (1024 * 1024)
    return f"{rss_megabytes:.1f} MB"
|
||||||
|
|
||||||
|
def optimize_cpu_threading():
    """Choose a CPU thread count, export it to OpenMP/MKL and apply it to torch.

    Heuristic: use every logical CPU when ``platform.processor()`` mentions
    "amd", otherwise use half of them (never fewer than one).

    Side effects: sets ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` in
    ``os.environ``, calls ``torch.set_num_threads`` and prints a summary.

    Returns:
        int: the thread count that was applied.
    """
    cpu_count = multiprocessing.cpu_count()

    # Get processor information
    processor_name = platform.processor()

    if "amd" in processor_name.lower():
        optimal_threads = max(1, cpu_count)
    else:  # Intel or other
        optimal_threads = max(1, cpu_count // 2)

    thread_setting = str(optimal_threads)
    os.environ["OMP_NUM_THREADS"] = thread_setting
    os.environ["MKL_NUM_THREADS"] = thread_setting

    blas_info = "unknown"
    try:
        # Use the module-level ``torch``. A function-scope
        # ``import torch.__config__`` would turn ``torch`` into a local name,
        # so a failed import would later raise UnboundLocalError at
        # ``torch.set_num_threads`` below.
        config_info = torch.__config__.show().lower()
    except Exception:
        # Best effort only: the BLAS backend is informational, so a failure
        # to read the build configuration must not abort thread setup.
        config_info = ""
    if "mkl" in config_info:
        blas_info = "MKL"
    elif "openblas" in config_info:
        blas_info = "OpenBLAS"

    if DEBUG:
        print(f"CPU Optimization: {processor_name}")
        # NOTE(review): "physical cores" is an estimate that assumes two
        # hardware threads per core; it is not measured.
        print(f"Physical cores: {cpu_count // 2}, Total threads: {cpu_count}")
        print(f"Using {optimal_threads} threads for computation")
        print(f"BLAS backend: {blas_info}")
    else:
        print(f"CPU threading optimized: using {optimal_threads} threads")

    torch.set_num_threads(optimal_threads)
    return optimal_threads
|
||||||
|
|
||||||
class LoRAFromCivitai:
|
class LoRAFromCivitai:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -18,110 +85,327 @@ class LoRAFromCivitai:
|
|||||||
self.lora_prefix = []
|
self.lora_prefix = []
|
||||||
self.renamed_lora_prefix = {}
|
self.renamed_lora_prefix = {}
|
||||||
self.special_keys = {}
|
self.special_keys = {}
|
||||||
|
self.stats = {
|
||||||
|
"tensor_movements_to_gpu": 0,
|
||||||
|
"tensor_movements_to_cpu": 0,
|
||||||
|
"lora_weights_processed": 0,
|
||||||
|
"format_conversions": 0,
|
||||||
|
}
|
||||||
|
# Set optimal thread count for CPU operations
|
||||||
|
self.optimal_threads = optimize_cpu_threading()
|
||||||
|
self.use_gpu = torch.cuda.is_available()
|
||||||
|
|
||||||
|
# Enable tensor cores for matrix operations if available
|
||||||
|
if self.use_gpu and hasattr(torch.backends, 'cudnn'):
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
@timing_decorator
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
    """Dispatch to the up/down or the A/B converter based on the key names.

    If any key contains ``".lora_up"`` the up/down converter is used;
    otherwise the A/B converter is used. The chosen converter's result is
    returned as-is.
    """
    if DEBUG:
        print(f"Converting state dict with prefix {lora_prefix}, memory usage: {memory_usage()}")

    # Detect format
    has_up_down_keys = any(".lora_up" in name for name in state_dict)
    if has_up_down_keys:
        if DEBUG:
            print(f"Detected up/down format, keys: {len(state_dict)}")
        return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)

    if DEBUG:
        print(f"Detected A/B format, keys: {len(state_dict)}")
    return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
|
||||||
|
|
||||||
|
@timing_decorator
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
    """Merge ``lora_up``/``lora_down`` pairs into per-weight delta tensors.

    For every key that contains ``".lora_up"`` and starts with *lora_prefix*,
    the matching ``".lora_down"`` tensor is looked up and the product
    ``alpha * up @ down`` is stored (on CPU) under the derived target name.

    Args:
        state_dict: mapping of LoRA key -> tensor.
        lora_prefix: only keys starting with this prefix are converted.
        alpha: scale applied to each product.

    Returns:
        dict mapping target weight name -> delta tensor on CPU.

    Raises:
        KeyError: if a ``lora_up`` key has no ``lora_down`` counterpart.
    """
    renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
    state_dict_ = {}
    if DEBUG:
        print(f"Processing up/down conversion for {len(state_dict)} tensors...")

    # Determine optimal processing device
    device = "cuda" if self.use_gpu else "cpu"
    torch_dtype = torch.float16 if self.use_gpu else torch.float32

    # Collect applicable keys first so the progress bar has a total.
    applicable_keys = [
        key for key in state_dict
        if ".lora_up" in key and key.startswith(lora_prefix)
    ]

    BATCH_SIZE = 16  # Adjust based on memory constraints
    if DEBUG:
        print(f"Processing {len(applicable_keys)} tensors in batches of {BATCH_SIZE}...")

    with tqdm(total=len(applicable_keys), desc="Converting up/down weights") as pbar:
        for i in range(0, len(applicable_keys), BATCH_SIZE):
            for key in applicable_keys[i:i + BATCH_SIZE]:
                weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
                weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(
                    device=device, dtype=torch_dtype)
                # Counted even when device is "cpu"; the counter name is historical.
                self.stats["tensor_movements_to_gpu"] += 2

                if len(weight_up.shape) == 4:
                    # 4-D weights: drop the two trailing axes, multiply in
                    # float32, then restore the axes on the product.
                    weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                    weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                    lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                else:
                    lora_weight = alpha * torch.mm(weight_up, weight_down)

                target_name = (
                    key.split(".")[0]
                    .replace(lora_prefix, renamed_lora_prefix)
                    .replace("_", ".")
                    + ".weight"
                )
                # Special-key substitutions rewrite the *name*. They must run
                # before the tensor is stored: calling .replace() on the
                # stored tensor fails and never renames the key.
                for special_key, replacement in self.special_keys.items():
                    target_name = target_name.replace(special_key, replacement)

                state_dict_[target_name] = lora_weight.cpu()
                self.stats["tensor_movements_to_cpu"] += 1
                self.stats["lora_weights_processed"] += 1
                pbar.update(1)

                del weight_up, weight_down, lora_weight

            # Clear memory after each batch
            if self.use_gpu:
                torch.cuda.empty_cache()

    if DEBUG:
        print(f"Up/down conversion complete, resulting in {len(state_dict_)} tensors, memory: {memory_usage()}")
    else:
        print(f"LoRA conversion complete: {len(state_dict_)} tensors processed")
    return state_dict_
|
||||||
|
|
||||||
|
@timing_decorator
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device=None, torch_dtype=None):
    """Merge ``lora_A``/``lora_B`` pairs into per-weight delta tensors.

    For every key that contains ``".lora_B."`` and starts with *lora_prefix*,
    the product ``alpha * B @ A`` is stored (on CPU) under the key with the
    ``lora_B`` component and the prefix removed.

    Args:
        state_dict: mapping of LoRA key -> tensor.
        lora_prefix: only keys starting with this prefix are converted.
        alpha: scale applied to each product.
        device: optional compute device. ``None`` selects "cuda" when
            ``self.use_gpu`` is set and "cpu" otherwise.
        torch_dtype: optional compute dtype. ``None`` selects float16 on GPU
            and float32 on CPU.

    Returns:
        dict mapping target weight name -> delta tensor on CPU.

    Raises:
        KeyError: if a ``lora_B`` key has no ``lora_A`` counterpart.
    """
    state_dict_ = {}

    # ``device`` / ``torch_dtype`` are accepted again so callers that pass
    # them explicitly keep working; omitting them gives automatic selection.
    if device is None:
        device = "cuda" if self.use_gpu else "cpu"
    if torch_dtype is None:
        torch_dtype = torch.float16 if self.use_gpu else torch.float32

    # Collect applicable keys first so the progress bar has a total.
    applicable_keys = [
        key for key in state_dict
        if ".lora_B." in key and key.startswith(lora_prefix)
    ]

    BATCH_SIZE = 16  # Adjust based on memory constraints
    if DEBUG:
        print(f"Processing {len(applicable_keys)} tensors in batches of {BATCH_SIZE}...")

    with tqdm(total=len(applicable_keys), desc="Converting A/B weights") as pbar:
        for i in range(0, len(applicable_keys), BATCH_SIZE):
            for key in applicable_keys[i:i + BATCH_SIZE]:
                weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
                weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(
                    device=device, dtype=torch_dtype)
                # Counted even when device is "cpu"; the counter name is historical.
                self.stats["tensor_movements_to_gpu"] += 2

                if len(weight_up.shape) == 4:
                    # 4-D weights: drop the two trailing axes for the matmul,
                    # then restore them on the product.
                    weight_up = weight_up.squeeze(3).squeeze(2)
                    weight_down = weight_down.squeeze(3).squeeze(2)
                    lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                else:
                    lora_weight = alpha * torch.mm(weight_up, weight_down)

                # Target name: drop the first "lora_B" component, then the prefix.
                name_parts = key.split(".")
                name_parts.remove("lora_B")
                target_name = ".".join(name_parts)[len(lora_prefix):]

                state_dict_[target_name] = lora_weight.cpu()
                self.stats["tensor_movements_to_cpu"] += 1
                self.stats["lora_weights_processed"] += 1
                pbar.update(1)

                del weight_up, weight_down, lora_weight

            # Clear memory after each batch
            if self.use_gpu:
                torch.cuda.empty_cache()

    if DEBUG:
        print(f"A/B conversion complete, resulting in {len(state_dict_)} tensors, memory: {memory_usage()}")
    else:
        print(f"LoRA conversion complete: {len(state_dict_)} tensors processed")
    return state_dict_
|
||||||
|
|
||||||
|
@timing_decorator
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
    """Convert a LoRA state dict and add the deltas to *model*'s parameters.

    Steps: map parameter names to parameters, merge the LoRA pairs via
    ``convert_state_dict``, optionally rename keys with the model's
    ``state_dict_converter`` ("diffusers" or "civitai"), then add each delta
    to the parameter of the same name. Deltas whose name is not a model
    parameter are skipped and reported.

    Args:
        model: module whose parameters are updated through ``param.data``.
        state_dict_lora: raw LoRA state dict.
        lora_prefix: key prefix handed to ``convert_state_dict``.
        alpha: LoRA scale.
        model_resource: "diffusers", "civitai", or a falsy value to skip
            key renaming.
    """
    print(f"Starting LoRA loading process for {model.__class__.__name__}...")

    # Map names to parameters directly instead of building model.state_dict().
    start_state_dict = time.time()
    state_dict_model = dict(model.named_parameters())
    end_state_dict = time.time()
    if DEBUG:
        print(f"⏱️ Loading model parameters took {end_state_dict - start_state_dict:.4f} seconds, size: {len(state_dict_model)} tensors")
    else:
        print(f"Model parameters mapped: {len(state_dict_model)} parameters")

    # Measure LoRA conversion time
    start_convert = time.time()
    state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
    end_convert = time.time()
    if DEBUG:
        print(f"⏱️ LoRA conversion took {end_convert - start_convert:.4f} seconds")

    # Measure format conversion time if applicable
    if model_resource:
        if DEBUG:
            print(f"Converting format from {model_resource}...")
        start_format = time.time()
        if model_resource == "diffusers":
            state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
        elif model_resource == "civitai":
            state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
        # Counted once per load, and only here. A second increment right
        # after convert_state_dict used to double the statistic.
        self.stats["format_conversions"] += 1
        end_format = time.time()
        if DEBUG:
            print(f"⏱️ Format conversion took {end_format - start_format:.4f} seconds")

    if isinstance(state_dict_lora, tuple):
        state_dict_lora = state_dict_lora[0]

    if len(state_dict_lora) > 0:
        if DEBUG:
            print(f"Applying {len(state_dict_lora)} LoRA tensors to model weights...")
        else:
            print("Applying LoRA weights...")

        # Process in batches
        BATCH_SIZE = 32
        lora_keys = list(state_dict_lora.keys())
        skipped = 0

        start_update = time.time()
        with tqdm(total=len(lora_keys), desc="Applying LoRA weights") as pbar:
            for i in range(0, len(lora_keys), BATCH_SIZE):
                for name in lora_keys[i:i + BATCH_SIZE]:
                    param = state_dict_model.get(name)
                    if param is None:
                        # No parameter with this name; nothing to update.
                        skipped += 1
                    else:
                        delta = state_dict_lora[name]
                        if param.dtype == torch.float8_e4m3fn:
                            # FP8: add in the LoRA tensor's dtype, then cast back.
                            widened = param.data.to(delta.dtype)
                            summed = widened + delta.to(dtype=widened.dtype, device=widened.device)
                            param.data = summed.to(torch.float8_e4m3fn)
                        else:
                            # Direct update (avoids load_state_dict overhead)
                            param.data = param.data + delta.to(dtype=param.dtype, device=param.device)
                    pbar.update(1)

                # Clear memory after each batch
                if self.use_gpu:
                    torch.cuda.empty_cache()

        end_update = time.time()
        if DEBUG:
            print(f"⏱️ Weight update took {end_update - start_update:.4f} seconds")
        else:
            print("Weight update complete.")
        if skipped:
            # Surface silent skips so a partially matching LoRA is visible.
            print(f"{skipped} LoRA tensors had no matching model parameter and were skipped.")
    else:
        print("No LoRA tensors to apply!")

    if DEBUG:
        print("\n==== LoRA LOADING STATISTICS ====")
        print(f"Total tensor movements to GPU: {self.stats['tensor_movements_to_gpu']}")
        print(f"Total tensor movements to CPU: {self.stats['tensor_movements_to_cpu']}")
        print(f"Total LoRA weights processed: {self.stats['lora_weights_processed']}")
        print(f"Total format conversions: {self.stats['format_conversions']}")
        print(f"Final memory usage: {memory_usage()}")
        print("================================")
    else:
        print(f"LoRA load complete: {self.stats['lora_weights_processed']} weights processed, GPU moves: {self.stats['tensor_movements_to_gpu']}, CPU moves: {self.stats['tensor_movements_to_cpu']}.")

    # Clear temporary data and run garbage collection
    del state_dict_lora
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@timing_decorator
def match(self, model, state_dict_lora):
    """Find the (prefix, format) pair under which the LoRA keys fit *model*.

    For each registered prefix whose model class matches *model*, the LoRA is
    converted and renamed with the "diffusers" and then the "civitai"
    converter. A combination matches only when the conversion is non-empty
    and every resulting key names a parameter of the model.

    Args:
        model: candidate model.
        state_dict_lora: raw LoRA state dict.

    Returns:
        ``(lora_prefix, model_resource)`` for the first matching combination,
        or ``None`` when nothing matches.
    """
    if DEBUG:
        print(f"Trying to match LoRA format for {model.__class__.__name__}, memory usage: {memory_usage()}")

    # The model's parameter names do not depend on the prefix; collect once.
    param_names = {name for name, _ in model.named_parameters()}

    for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
        if not isinstance(model, model_class):
            continue

        if DEBUG:
            print(f"Checking prefix '{lora_prefix}' for model class {model_class.__name__}")

        for model_resource in ["diffusers", "civitai"]:
            try:
                if DEBUG:
                    print(f" Attempting {model_resource} format...")
                start_time = time.time()

                # Try conversion
                state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
                converter = model.__class__.state_dict_converter()
                converter_fn = converter.from_diffusers if model_resource == "diffusers" \
                    else converter.from_civitai
                state_dict_lora_ = converter_fn(state_dict_lora_)

                if isinstance(state_dict_lora_, tuple):
                    state_dict_lora_ = state_dict_lora_[0]

                if len(state_dict_lora_) == 0:
                    if DEBUG:
                        print(f" ❌ No matching tensors found for {model_resource} format")
                    continue

                # Require every converted key to exist in the model. Accepting
                # after the first hit would report a match for a LoRA whose
                # remaining tensors are then silently skipped by load().
                missing_key = next(
                    (name for name in state_dict_lora_ if name not in param_names), None)
                end_time = time.time()

                if missing_key is not None:
                    if DEBUG:
                        print(f" ⚠️ Key not found in model: {missing_key}")
                    continue

                if DEBUG:
                    valid_keys = len(state_dict_lora_)
                    print(f" ✅ Match found! Prefix: {lora_prefix}, Format: {model_resource}, Valid keys: {valid_keys}")
                    print(f" ⏱️ Match verification took {end_time - start_time:.4f} seconds")
                else:
                    print("Matching format found.")
                return lora_prefix, model_resource

            except Exception as e:
                # Probing is best effort: a converter that cannot handle this
                # layout just means this combination is not a match.
                if DEBUG:
                    print(f" ❌ Error during matching: {str(e)}")

    if DEBUG:
        print("❌ No match found for any format or prefix")
    return None
|
||||||
|
|
||||||
|
# Specialized classes derived from LoRAFromCivitai
|
||||||
|
|
||||||
class SDLoRAFromCivitai(LoRAFromCivitai):
|
class SDLoRAFromCivitai(LoRAFromCivitai):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -148,7 +432,6 @@ class SDLoRAFromCivitai(LoRAFromCivitai):
|
|||||||
"output.blocks": "model.diffusion_model.output_blocks",
|
"output.blocks": "model.diffusion_model.output_blocks",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class SDXLLoRAFromCivitai(LoRAFromCivitai):
|
class SDXLLoRAFromCivitai(LoRAFromCivitai):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -176,7 +459,6 @@ class SDXLLoRAFromCivitai(LoRAFromCivitai):
|
|||||||
"output.blocks": "model.diffusion_model.output_blocks",
|
"output.blocks": "model.diffusion_model.output_blocks",
|
||||||
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
|
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class FluxLoRAFromCivitai(LoRAFromCivitai):
|
class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -195,14 +477,29 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
|
|||||||
"txt.mod": "txt_mod",
|
"txt.mod": "txt_mod",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
    """LoRA loader configuration for ``HunyuanVideoDiT`` with two accepted key prefixes."""

    def __init__(self):
        super().__init__()
        # One model-class entry per prefix; both prefixes map to HunyuanVideoDiT.
        prefixes = ("diffusion_model.", "transformer.")
        self.supported_model_classes = [HunyuanVideoDiT for _ in prefixes]
        self.lora_prefix = list(prefixes)
        self.special_keys = dict()
|
||||||
|
|
||||||
class GeneralLoRAFromPeft:
|
class GeneralLoRAFromPeft:
|
||||||
def __init__(self):
    """Register supported model classes, zero the counters and configure torch."""
    self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]

    # Counters reported by load(); all start at zero.
    counter_names = (
        "tensor_movements_to_gpu",
        "tensor_movements_to_cpu",
        "lora_weights_processed",
        "parameter_updates",
    )
    self.stats = dict.fromkeys(counter_names, 0)

    # Applies the process-wide CPU thread settings and remembers the count.
    self.optimal_threads = optimize_cpu_threading()
    self.use_gpu = torch.cuda.is_available()

    # With CUDA available and a cudnn backend present, turn on cudnn.benchmark.
    if self.use_gpu:
        if hasattr(torch.backends, 'cudnn'):
            torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
def _get_target_name(self, key):
|
def _get_target_name(self, key):
|
||||||
"""Extract target parameter name from LoRA key"""
|
"""Extract target parameter name from LoRA key"""
|
||||||
@@ -215,105 +512,170 @@ class GeneralLoRAFromPeft:
|
|||||||
target_name = target_name[len("diffusion_model."):]
|
target_name = target_name[len("diffusion_model."):]
|
||||||
return target_name
|
return target_name
|
||||||
|
|
||||||
|
@timing_decorator
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict=None):
    """Merge ``lora_A``/``lora_B`` pairs into per-weight delta tensors.

    Args:
        state_dict: mapping of LoRA key -> tensor.
        alpha: scale applied to each ``B @ A`` product.
        target_state_dict: optional container of valid target names. When it
            is non-empty, pairs whose target name is not in it are skipped;
            when it is ``None`` or empty, every pair is converted.

    Returns:
        dict mapping target name -> delta tensor on CPU.

    Raises:
        KeyError: if a ``lora_B`` key has no ``lora_A`` counterpart.
    """
    if DEBUG:
        print(f"Converting state dict with GeneralLoRAFromPeft, memory: {memory_usage()}")
    device = "cuda" if self.use_gpu else "cpu"
    torch_dtype = torch.float16 if self.use_gpu else torch.float32

    state_dict_ = {}

    # Count applicable keys
    applicable_keys = [key for key in state_dict if ".lora_B." in key]

    # Process in batches
    BATCH_SIZE = 16
    if DEBUG:
        print(f"Processing {len(applicable_keys)} tensors in batches of {BATCH_SIZE}...")

    with tqdm(total=len(applicable_keys), desc="Converting LoRA weights") as pbar:
        for i in range(0, len(applicable_keys), BATCH_SIZE):
            for key in applicable_keys[i:i + BATCH_SIZE]:
                # Resolve and filter the name first so no transfer or matmul
                # is spent on a pair that would be discarded anyway.
                target_name = self._get_target_name(key)
                if target_state_dict and target_name not in target_state_dict:
                    pbar.update(1)
                    continue

                weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
                weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(
                    device=device, dtype=torch_dtype)
                # Counted even when device is "cpu"; the counter name is historical.
                self.stats["tensor_movements_to_gpu"] += 2

                if len(weight_up.shape) == 4:
                    # 4-D weights: drop the two trailing axes for the matmul,
                    # then restore them on the product.
                    weight_up = weight_up.squeeze(3).squeeze(2)
                    weight_down = weight_down.squeeze(3).squeeze(2)
                    lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                else:
                    lora_weight = alpha * torch.mm(weight_up, weight_down)

                state_dict_[target_name] = lora_weight.cpu()
                self.stats["tensor_movements_to_cpu"] += 1
                self.stats["lora_weights_processed"] += 1
                pbar.update(1)

                del weight_up, weight_down, lora_weight

            # Clear memory after each batch
            if self.use_gpu:
                torch.cuda.empty_cache()

    if DEBUG:
        print(f"Conversion complete, resulting in {len(state_dict_)} tensors, memory: {memory_usage()}")
    else:
        print(f"General LoRA conversion complete: {len(state_dict_)} weights processed")

    return state_dict_
|
||||||
|
|
||||||
|
@timing_decorator
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
    """Apply LoRA weights directly to model parameters with batched processing.

    For every ``lora_B`` key whose target name is a parameter of *model*,
    ``alpha * B @ A`` is computed and added to that parameter through
    ``param.data``. Keys without a matching parameter are ignored.

    Args:
        model: module whose parameters are updated.
        state_dict_lora: mapping holding ``lora_A`` / ``lora_B`` tensors.
        lora_prefix: not used by this method.
        alpha: LoRA scale.
        model_resource: not used by this method.

    Raises:
        KeyError: if a ``lora_B`` key has no ``lora_A`` counterpart.
    """
    print(f"Starting optimized LoRA loading for {model.__class__.__name__}...")

    # Create parameter lookup dict
    start_map = time.time()
    param_dict = dict(model.named_parameters())
    end_map = time.time()
    if DEBUG:
        print(f"⏱️ Parameter mapping took {end_map - start_map:.4f} seconds, found {len(param_dict)} parameters")
    else:
        print(f"Mapped {len(param_dict)} model parameters")

    # Count applicable LoRA parameters
    lora_b_keys = [key for key in state_dict_lora if ".lora_B." in key]
    print(f"Found {len(lora_b_keys)} LoRA parameter pairs to process")

    # Group (lora key, target name) pairs by the shape of the lora_B tensor.
    shape_groups = {}
    for key in lora_b_keys:
        target_name = self._get_target_name(key)
        if target_name not in param_dict:
            continue
        shape_groups.setdefault(state_dict_lora[key].shape, []).append((key, target_name))

    if DEBUG:
        print(f"Organized into {len(shape_groups)} shape groups for efficient processing")

    # Process each shape group in batches
    BATCH_SIZE = 32
    modified_count = 0
    compute_device = "cuda" if self.use_gpu else "cpu"

    for shape, key_pairs in shape_groups.items():
        if DEBUG:
            print(f"Processing {len(key_pairs)} parameters with shape {shape}")
        for i in range(0, len(key_pairs), BATCH_SIZE):
            for lora_key, target_name in key_pairs[i:i + BATCH_SIZE]:
                param = param_dict[target_name]
                is_fp8 = param.dtype == torch.float8_e4m3fn
                # FP8 parameters are updated in float32 and cast back.
                dtype_for_calc = torch.float32 if is_fp8 else param.dtype

                # Load weights and compute LoRA update
                weight_b = state_dict_lora[lora_key].to(device=compute_device, dtype=dtype_for_calc)
                weight_a = state_dict_lora[lora_key.replace(".lora_B.", ".lora_A.")].to(
                    device=compute_device, dtype=dtype_for_calc)
                self.stats["tensor_movements_to_gpu"] += 2

                if len(weight_b.shape) == 4:
                    weight_b = weight_b.squeeze(3).squeeze(2)
                    weight_a = weight_a.squeeze(3).squeeze(2)
                    lora_weight = alpha * torch.mm(weight_b, weight_a).unsqueeze(2).unsqueeze(3)
                else:
                    lora_weight = alpha * torch.mm(weight_b, weight_a)

                # The delta may live on the compute device while the parameter
                # lives elsewhere (e.g. CPU-resident weights with CUDA
                # available). Move it before adding, in the FP8 branch too.
                lora_weight = lora_weight.to(device=param.device)

                # Apply update directly to parameter
                if is_fp8:
                    param_float = param.data.to(torch.float32)
                    param.data = (param_float + lora_weight).to(param.dtype)
                    del param_float
                else:
                    param.data += lora_weight.to(dtype=param.dtype)

                del weight_a, weight_b, lora_weight
                self.stats["parameter_updates"] += 1
                # Reported in the statistics below, so it has to be counted here.
                self.stats["lora_weights_processed"] += 1
                modified_count += 1

            if self.use_gpu:
                torch.cuda.empty_cache()

    if DEBUG:
        print("\n==== OPTIMIZED LORA LOADING STATISTICS ====")
        print(f"Total tensor movements to GPU: {self.stats['tensor_movements_to_gpu']}")
        print(f"Total LoRA weights processed: {self.stats['lora_weights_processed']}")
        print(f"Total parameters updated: {self.stats['parameter_updates']}")
        print(f"Final memory usage: {memory_usage()}")
        print("==========================================")
    else:
        print(f"Optimized LoRA load complete: updated {modified_count} tensors")

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"⏱️ {modified_count} tensors were updated successfully")
|
||||||
|
|
||||||
|
@timing_decorator
|
||||||
def match(self, model, state_dict_lora):
|
def match(self, model, state_dict_lora):
|
||||||
"""Check if LoRA parameters match model parameters without loading full state dict"""
|
"""Check if LoRA parameters match model parameters"""
|
||||||
|
if DEBUG:
|
||||||
|
print(f"Checking General LoRA compatibility for {model.__class__.__name__}...")
|
||||||
|
|
||||||
for model_class in self.supported_model_classes:
|
for model_class in self.supported_model_classes:
|
||||||
if not isinstance(model, model_class):
|
if not isinstance(model, model_class):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Create set of parameter names
|
# Create set of parameter names
|
||||||
param_names = set()
|
start_param = time.time()
|
||||||
for name, _ in model.named_parameters():
|
param_names = set(name for name, _ in model.named_parameters())
|
||||||
param_names.add(name)
|
end_param = time.time()
|
||||||
|
if DEBUG:
|
||||||
|
print(f"⏱️ Parameter name collection took {end_param - start_param:.4f} seconds, found {len(param_names)} names")
|
||||||
|
|
||||||
# Check if a sample of LoRA keys map to model parameters
|
# Check if a sample of LoRA keys map to model parameters
|
||||||
matched_count = 0
|
matched_count = 0
|
||||||
checked_count = 0
|
checked_count = 0
|
||||||
|
|
||||||
|
start_check = time.time()
|
||||||
for key in state_dict_lora:
|
for key in state_dict_lora:
|
||||||
if ".lora_B." not in key:
|
if ".lora_B." not in key:
|
||||||
continue
|
continue
|
||||||
@@ -324,30 +686,34 @@ class GeneralLoRAFromPeft:
|
|||||||
|
|
||||||
checked_count += 1
|
checked_count += 1
|
||||||
if matched_count >= 5: # Found enough matches
|
if matched_count >= 5: # Found enough matches
|
||||||
return "", ""
|
break
|
||||||
if checked_count >= 50 and matched_count == 0: # Checked enough without matches
|
if checked_count >= 50 and matched_count == 0: # Checked enough without matches
|
||||||
break
|
break
|
||||||
|
end_check = time.time()
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
print(f"⏱️ Match check took {end_check - start_check:.4f} seconds")
|
||||||
|
print(f"Matched {matched_count}/{checked_count} checked parameters")
|
||||||
|
|
||||||
if matched_count > 0:
|
if matched_count > 0:
|
||||||
|
if DEBUG:
|
||||||
|
print(f"✅ Compatible with GeneralLoRAFromPeft")
|
||||||
|
else:
|
||||||
|
print("LoRA compatibility check: PASS")
|
||||||
return "", ""
|
return "", ""
|
||||||
|
if DEBUG:
|
||||||
|
print("❌ Not compatible with GeneralLoRAFromPeft")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
    """LoRAFromCivitai loader configured for ``HunyuanVideoDiT`` models."""

    def __init__(self):
        super().__init__()
        # The two key prefixes this loader is configured with.
        self.lora_prefix = ["diffusion_model.", "transformer."]
        # NOTE(review): HunyuanVideoDiT is listed twice, presumably one entry
        # per prefix above (matched positionally by the base class) — confirm
        # against LoRAFromCivitai before de-duplicating this list.
        self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
        # No special key overrides are configured for this loader.
        self.special_keys = dict()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxLoRAConverter:
|
class FluxLoRAConverter:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@timing_decorator
|
||||||
def align_to_opensource_format(state_dict, alpha=1.0):
|
def align_to_opensource_format(state_dict, alpha=1.0):
|
||||||
|
if DEBUG:
|
||||||
|
print(f"Converting Flux LoRA to opensource format, input keys: {len(state_dict)}")
|
||||||
prefix_rename_dict = {
|
prefix_rename_dict = {
|
||||||
"single_blocks": "lora_unet_single_blocks",
|
"single_blocks": "lora_unet_single_blocks",
|
||||||
"blocks": "lora_unet_double_blocks",
|
"blocks": "lora_unet_double_blocks",
|
||||||
@@ -356,7 +722,6 @@ class FluxLoRAConverter:
|
|||||||
"norm.linear": "modulation_lin",
|
"norm.linear": "modulation_lin",
|
||||||
"to_qkv_mlp": "linear1",
|
"to_qkv_mlp": "linear1",
|
||||||
"proj_out": "linear2",
|
"proj_out": "linear2",
|
||||||
|
|
||||||
"norm1_a.linear": "img_mod_lin",
|
"norm1_a.linear": "img_mod_lin",
|
||||||
"norm1_b.linear": "txt_mod_lin",
|
"norm1_b.linear": "txt_mod_lin",
|
||||||
"attn.a_to_qkv": "img_attn_qkv",
|
"attn.a_to_qkv": "img_attn_qkv",
|
||||||
@@ -373,7 +738,7 @@ class FluxLoRAConverter:
|
|||||||
"lora_A.weight": "lora_down.weight",
|
"lora_A.weight": "lora_down.weight",
|
||||||
}
|
}
|
||||||
state_dict_ = {}
|
state_dict_ = {}
|
||||||
for name, param in state_dict.items():
|
for name, param in tqdm(state_dict.items(), desc="Aligning to opensource format"):
|
||||||
names = name.split(".")
|
names = name.split(".")
|
||||||
if names[-2] != "lora_A" and names[-2] != "lora_B":
|
if names[-2] != "lora_A" and names[-2] != "lora_B":
|
||||||
names.pop(-2)
|
names.pop(-2)
|
||||||
@@ -387,10 +752,17 @@ class FluxLoRAConverter:
|
|||||||
state_dict_[rename] = param
|
state_dict_[rename] = param
|
||||||
if rename.endswith("lora_up.weight"):
|
if rename.endswith("lora_up.weight"):
|
||||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
||||||
|
if DEBUG:
|
||||||
|
print(f"Conversion complete, output keys: {len(state_dict_)}")
|
||||||
|
else:
|
||||||
|
print(f"Flux LoRA conversion complete: {len(state_dict_)} keys")
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@timing_decorator
|
||||||
def align_to_diffsynth_format(state_dict):
|
def align_to_diffsynth_format(state_dict):
|
||||||
|
if DEBUG:
|
||||||
|
print(f"Converting to diffsynth format, input keys: {len(state_dict)}")
|
||||||
rename_dict = {
|
rename_dict = {
|
||||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||||
@@ -426,7 +798,7 @@ class FluxLoRAConverter:
|
|||||||
return i, name.replace(f"_{i}_", "_blockid_")
|
return i, name.replace(f"_{i}_", "_blockid_")
|
||||||
return None, None
|
return None, None
|
||||||
state_dict_ = {}
|
state_dict_ = {}
|
||||||
for name, param in state_dict.items():
|
for name, param in tqdm(state_dict.items(), desc="Aligning to diffsynth format"):
|
||||||
block_id, source_name = guess_block_id(name)
|
block_id, source_name = guess_block_id(name)
|
||||||
if source_name in rename_dict:
|
if source_name in rename_dict:
|
||||||
target_name = rename_dict[source_name]
|
target_name = rename_dict[source_name]
|
||||||
@@ -434,8 +806,11 @@ class FluxLoRAConverter:
|
|||||||
state_dict_[target_name] = param
|
state_dict_[target_name] = param
|
||||||
else:
|
else:
|
||||||
state_dict_[name] = param
|
state_dict_[name] = param
|
||||||
|
if DEBUG:
|
||||||
|
print(f"Conversion complete, output keys: {len(state_dict_)}")
|
||||||
|
else:
|
||||||
|
print(f"Diffsynth conversion complete: {len(state_dict_)} keys")
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
    """Return a new list holding one fresh instance of each LoRA loader.

    The list order is fixed: the SD, SDXL, Flux and HunyuanVideo
    Civitai-style loaders come first, and ``GeneralLoRAFromPeft`` is last.
    """
    loader_classes = (
        SDLoRAFromCivitai,
        SDXLLoRAFromCivitai,
        FluxLoRAFromCivitai,
        HunyuanVideoLoRAFromCivitai,
        GeneralLoRAFromPeft,
    )
    return [loader_cls() for loader_cls in loader_classes]
|
||||||
|
|||||||
Reference in New Issue
Block a user