diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py
index 2315e96..880bb76 100644
--- a/diffsynth/models/lora.py
+++ b/diffsynth/models/lora.py
@@ -1,4 +1,12 @@
 import torch
+import time
+import functools
+from tqdm import tqdm
+import psutil
+import gc
+import os
+import platform
+import multiprocessing
 from .sd_unet import SDUNet
 from .sdxl_unet import SDXLUNet
 from .sd_text_encoder import SDTextEncoder
@@ -10,7 +18,68 @@ from .cog_dit import CogDiT
 from .hunyuan_video_dit import HunyuanVideoDiT
 from .wan_video_dit import WanModel
 
+# Global debug flag: when set to False, only minimal info is printed.
+DEBUG = False
+
+def debug_print(*args, **kwargs):
+    """Print debug messages only if DEBUG is True."""
+    if DEBUG:
+        print(*args, **kwargs)
+
+def timing_decorator(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        if DEBUG:
+            print(f"⏱️ {func.__name__} took {elapsed_time:.4f} seconds")
+        return result
+    return wrapper
+
+def memory_usage():
+    """Get current memory usage of the process."""
+    process = psutil.Process(os.getpid())
+    memory_info = process.memory_info()
+    return f"{memory_info.rss / (1024 * 1024):.1f} MB"
+
+def optimize_cpu_threading():
+    """Set a heuristic thread configuration for the current CPU."""
+    cpu_count = multiprocessing.cpu_count()
+
+    # Get processor information
+    processor = platform.processor().lower()
+
+    if "amd" in processor:
+        optimal_threads = max(1, cpu_count)
+    else:  # Intel or other
+        optimal_threads = max(1, cpu_count // 2)
+
+    os.environ["OMP_NUM_THREADS"] = str(optimal_threads)
+    os.environ["MKL_NUM_THREADS"] = str(optimal_threads)
+
+    blas_info = "unknown"
+    try:
+        import torch.__config__
+        config_info = torch.__config__.show()
+        if "mkl" in config_info.lower():
+            blas_info = "MKL"
+        elif "openblas" in config_info.lower():
+            blas_info = "OpenBLAS"
+    except Exception:
+        pass
+
+    if DEBUG:
+        print(f"CPU Optimization: {platform.processor()}")
+        print(f"Physical cores: {psutil.cpu_count(logical=False)}, Logical threads: {cpu_count}")
+        print(f"Using {optimal_threads} threads for computation")
+        print(f"BLAS backend: {blas_info}")
+    else:
+        print(f"CPU threading optimized: using {optimal_threads} threads")
+
+    torch.set_num_threads(optimal_threads)
+    return optimal_threads
 
 class LoRAFromCivitai:
     def __init__(self):
@@ -18,110 +87,326 @@ class LoRAFromCivitai:
         self.lora_prefix = []
         self.renamed_lora_prefix = {}
         self.special_keys = {}
+        self.stats = {
+            "tensor_movements_to_gpu": 0,
+            "tensor_movements_to_cpu": 0,
+            "lora_weights_processed": 0,
+            "format_conversions": 0,
+        }
+        # Set optimal thread count for CPU operations
+        self.optimal_threads = optimize_cpu_threading()
+        self.use_gpu = torch.cuda.is_available()
+
+        # Enable tensor cores for matrix operations if available
+        if self.use_gpu and hasattr(torch.backends, 'cudnn'):
+            torch.backends.cudnn.benchmark = True
-
+    @timing_decorator
     def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
+        if DEBUG:
+            print(f"Converting state dict with prefix {lora_prefix}, memory usage: {memory_usage()}")
+        # Detect format
         for key in state_dict:
             if ".lora_up" in key:
+                if DEBUG:
+                    print(f"Detected up/down format, keys: {len(state_dict)}")
                 return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
+        if DEBUG:
+            print(f"Detected A/B format, keys: {len(state_dict)}")
         return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
-
+    @timing_decorator
     def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
         renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
         state_dict_ = {}
+        if DEBUG:
+            print(f"Processing up/down conversion for {len(state_dict)} tensors...")
+
+        # Determine optimal processing device
+        device = "cuda" if self.use_gpu else "cpu"
+        torch_dtype = torch.float16 if self.use_gpu else torch.float32
+
+        # Count applicable keys first
+        applicable_keys = []
         for key in state_dict:
-            if ".lora_up" not in key:
-                continue
-            if not key.startswith(lora_prefix):
-                continue
-            weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
-            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
-            if len(weight_up.shape) == 4:
-                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
-                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
-                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
-            else:
-                lora_weight = alpha * torch.mm(weight_up, weight_down)
-            target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
-            for special_key in self.special_keys:
-                target_name = target_name.replace(special_key, self.special_keys[special_key])
-            state_dict_[target_name] = lora_weight.cpu()
+            if ".lora_up" in key and key.startswith(lora_prefix):
+                applicable_keys.append(key)
+
+        # Prepare batches for processing
+        BATCH_SIZE = 16  # Adjust based on memory constraints
+        if DEBUG:
+            print(f"Processing {len(applicable_keys)} tensors in batches of {BATCH_SIZE}...")
+
+        with tqdm(total=len(applicable_keys), desc="Converting up/down weights") as pbar:
+            for i in range(0, len(applicable_keys), BATCH_SIZE):
+                batch_keys = applicable_keys[i:i+BATCH_SIZE]
+                for key in batch_keys:
+                    # Track GPU tensor movements
+                    weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
+                    weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch_dtype)
+                    self.stats["tensor_movements_to_gpu"] += 2
+
+                    # Matrix multiplication (on GPU when available, otherwise on the tuned CPU)
+                    if len(weight_up.shape) == 4:
+                        weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
+                        weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
+                        lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+                    else:
+                        lora_weight = alpha * torch.mm(weight_up, weight_down)
+
+                    target_key = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
+
+                    # Apply special key replacements to the parameter name (not the tensor)
+                    for special_key in self.special_keys:
+                        if special_key in target_key:
+                            target_key = target_key.replace(special_key, self.special_keys[special_key])
+
+                    state_dict_[target_key] = lora_weight.cpu()
+                    self.stats["tensor_movements_to_cpu"] += 1
+                    self.stats["lora_weights_processed"] += 1
+                    pbar.update(1)
+
+                # Clear memory after each batch
+                del weight_up, weight_down, lora_weight
+                if self.use_gpu:
+                    torch.cuda.empty_cache()
+
+        if DEBUG:
+            print(f"Up/down conversion complete, resulting in {len(state_dict_)} tensors, memory: {memory_usage()}")
+        else:
+            print(f"LoRA conversion complete: {len(state_dict_)} tensors processed")
         return state_dict_
-
-    def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
+    @timing_decorator
+    def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0):
         state_dict_ = {}
+        # Determine optimal processing device
+        device = "cuda" if self.use_gpu else "cpu"
+        torch_dtype = torch.float16 if self.use_gpu else torch.float32
+
+        # Collect applicable keys first
+        applicable_keys = []
         for key in state_dict:
-            if ".lora_B." not in key:
-                continue
-            if not key.startswith(lora_prefix):
-                continue
-            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
-            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
-            if len(weight_up.shape) == 4:
-                weight_up = weight_up.squeeze(3).squeeze(2)
-                weight_down = weight_down.squeeze(3).squeeze(2)
-                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
-            else:
-                lora_weight = alpha * torch.mm(weight_up, weight_down)
-            keys = key.split(".")
-            keys.pop(keys.index("lora_B"))
-            target_name = ".".join(keys)
-            target_name = target_name[len(lora_prefix):]
-            state_dict_[target_name] = lora_weight.cpu()
+            if ".lora_B." in key and key.startswith(lora_prefix):
+                applicable_keys.append(key)
+
+        # Prepare batches for processing
+        BATCH_SIZE = 16  # Adjust based on memory constraints
+        if DEBUG:
+            print(f"Processing {len(applicable_keys)} tensors in batches of {BATCH_SIZE}...")
+
+        with tqdm(total=len(applicable_keys), desc="Converting A/B weights") as pbar:
+            for i in range(0, len(applicable_keys), BATCH_SIZE):
+                batch_keys = applicable_keys[i:i+BATCH_SIZE]
+                for key in batch_keys:
+                    # Load and process tensors
+                    weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
+                    weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
+                    self.stats["tensor_movements_to_gpu"] += 2
+
+                    if len(weight_up.shape) == 4:
+                        weight_up = weight_up.squeeze(3).squeeze(2)
+                        weight_down = weight_down.squeeze(3).squeeze(2)
+                        lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+                    else:
+                        lora_weight = alpha * torch.mm(weight_up, weight_down)
+
+                    # Extract target name
+                    keys = key.split(".")
+                    keys.pop(keys.index("lora_B"))
+                    target_name = ".".join(keys)
+                    target_name = target_name[len(lora_prefix):]
+
+                    # Store result
+                    state_dict_[target_name] = lora_weight.cpu()
+                    self.stats["tensor_movements_to_cpu"] += 1
+                    self.stats["lora_weights_processed"] += 1
+                    pbar.update(1)
+
+                # Clear memory after each batch
+                del weight_up, weight_down, lora_weight
+                if self.use_gpu:
+                    torch.cuda.empty_cache()
+
+        if DEBUG:
+            print(f"A/B conversion complete, resulting in {len(state_dict_)} tensors, memory: {memory_usage()}")
+        else:
+            print(f"LoRA conversion complete: {len(state_dict_)} tensors processed")
         return state_dict_
-
+    @timing_decorator
     def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
-        state_dict_model = model.state_dict()
+        print(f"Starting LoRA loading process for {model.__class__.__name__}...")
+
+        # Measure state dict loading time - use direct parameter access
+        start_state_dict = time.time()
+        state_dict_model = {}
+        for name, param in model.named_parameters():
+            state_dict_model[name] = param
+        end_state_dict = time.time()
+        if DEBUG:
+            print(f"⏱️ Loading model parameters took {end_state_dict - start_state_dict:.4f} seconds, size: {len(state_dict_model)} tensors")
+        else:
+            print(f"Model parameters mapped: {len(state_dict_model)} parameters")
+
+        # Measure LoRA conversion time
+        start_convert = time.time()
         state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
-        if model_resource == "diffusers":
-            state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
-        elif model_resource == "civitai":
-            state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
+        self.stats["format_conversions"] += 1
+        end_convert = time.time()
+        if DEBUG:
+            print(f"⏱️ LoRA conversion took {end_convert - start_convert:.4f} seconds")
+
+        # Measure format conversion time if applicable
+        if model_resource:
+            if DEBUG:
+                print(f"Converting format from {model_resource}...")
+            start_format = time.time()
+            if model_resource == "diffusers":
+                state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
+            elif model_resource == "civitai":
+                state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
+            self.stats["format_conversions"] += 1
+            end_format = time.time()
+            if DEBUG:
+                print(f"⏱️ Format conversion took {end_format - start_format:.4f} seconds")
+
         if isinstance(state_dict_lora, tuple):
             state_dict_lora = state_dict_lora[0]
+
         if len(state_dict_lora) > 0:
-            print(f"    {len(state_dict_lora)} tensors are updated.")
-            for name in state_dict_lora:
-                fp8=False
-                if state_dict_model[name].dtype == torch.float8_e4m3fn:
-                    state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
-                    fp8=True
-                state_dict_model[name] += state_dict_lora[name].to(
-                    dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
-                if fp8:
-                    state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
-            model.load_state_dict(state_dict_model)
+            if DEBUG:
+                print(f"Applying {len(state_dict_lora)} LoRA tensors to model weights...")
+            else:
+                print("Applying LoRA weights...")
+
+            # Process in batches
+            BATCH_SIZE = 32
+            lora_keys = list(state_dict_lora.keys())
+
+            start_update = time.time()
+            with tqdm(total=len(lora_keys), desc="Applying LoRA weights") as pbar:
+                for i in range(0, len(lora_keys), BATCH_SIZE):
+                    batch_keys = lora_keys[i:i+BATCH_SIZE]
+                    for name in batch_keys:
+                        if name not in state_dict_model:
+                            pbar.update(1)
+                            continue
+
+                        param = state_dict_model[name]
+
+                        # Handle FP8 tensors
+                        fp8 = False
+                        if param.dtype == torch.float8_e4m3fn:
+                            param_data = param.to(state_dict_lora[name].dtype)
+                            fp8 = True
+                        else:
+                            param_data = param.data
+
+                        # Apply direct update (avoids load_state_dict overhead)
+                        param.data = param_data + state_dict_lora[name].to(
+                            dtype=param_data.dtype, device=param_data.device)
+
+                        if fp8:
+                            param.data = param.data.to(torch.float8_e4m3fn)
+
+                        pbar.update(1)
+
+                    # Clear memory after each batch
+                    if self.use_gpu:
+                        torch.cuda.empty_cache()
+
+            end_update = time.time()
+            if DEBUG:
+                print(f"⏱️ Weight update took {end_update - start_update:.4f} seconds")
+            else:
+                print("Weight update complete.")
+        else:
+            print("No LoRA tensors to apply!")
+
+        if DEBUG:
+            print("\n==== LoRA LOADING STATISTICS ====")
+            print(f"Total tensor movements to GPU: {self.stats['tensor_movements_to_gpu']}")
+            print(f"Total tensor movements to CPU: {self.stats['tensor_movements_to_cpu']}")
+            print(f"Total LoRA weights processed: {self.stats['lora_weights_processed']}")
+            print(f"Total format conversions: {self.stats['format_conversions']}")
+            print(f"Final memory usage: {memory_usage()}")
+            print("================================")
+        else:
+            print(f"LoRA load complete: {self.stats['lora_weights_processed']} weights processed, GPU moves: {self.stats['tensor_movements_to_gpu']}, CPU moves: {self.stats['tensor_movements_to_cpu']}.")
+
+        # Clear temporary data and run garbage collection
+        del state_dict_lora
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
-
+    @timing_decorator
     def match(self, model, state_dict_lora):
+        if DEBUG:
+            print(f"Trying to match LoRA format for {model.__class__.__name__}, memory usage: {memory_usage()}")
+
         for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
             if not isinstance(model, model_class):
                 continue
-            state_dict_model = model.state_dict()
+
+            if DEBUG:
+                print(f"Checking prefix '{lora_prefix}' for model class {model_class.__name__}")
+
+            # Get parameter names
+            param_names = set(name for name, _ in model.named_parameters())
+
             for model_resource in ["diffusers", "civitai"]:
                 try:
+                    if DEBUG:
+                        print(f"  Attempting {model_resource} format...")
+                    start_time = time.time()
+
+                    # Try conversion
                     state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
                     converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
                         else model.__class__.state_dict_converter().from_civitai
                     state_dict_lora_ = converter_fn(state_dict_lora_)
+
                     if isinstance(state_dict_lora_, tuple):
                         state_dict_lora_ = state_dict_lora_[0]
+
                     if len(state_dict_lora_) == 0:
+                        if DEBUG:
+                            print(f"  ❌ No matching tensors found for {model_resource} format")
                         continue
-                    for name in state_dict_lora_:
-                        if name not in state_dict_model:
+
+                    # Verify that a sample of the converted keys exists in the model
+                    valid_keys = 0
+                    for name in list(state_dict_lora_.keys())[:10]:
+                        if name in param_names:
+                            valid_keys += 1
+                        else:
+                            if DEBUG:
+                                print(f"  ⚠️ Key not found in model: {name}")
                             break
-                    else:
+
+                    end_time = time.time()
+
+                    if valid_keys > 0:
+                        if DEBUG:
+                            print(f"  ✅ Match found! Prefix: {lora_prefix}, Format: {model_resource}, Valid keys: {valid_keys}")
+                            print(f"  ⏱️ Match verification took {end_time - start_time:.4f} seconds")
+                        else:
+                            print("Matching format found.")
                         return lora_prefix, model_resource
-                except:
-                    pass
+                    else:
+                        if DEBUG:
+                            print("  ❌ No valid keys found for this format")
+
+                except Exception as e:
+                    if DEBUG:
+                        print(f"  ❌ Error during matching: {str(e)}")
+
+        if DEBUG:
+            print("❌ No match found for any format or prefix")
         return None
-
-
+# Specialized classes derived from LoRAFromCivitai
 class SDLoRAFromCivitai(LoRAFromCivitai):
     def __init__(self):
         super().__init__()
@@ -148,7 +433,6 @@ class SDLoRAFromCivitai(LoRAFromCivitai):
             "output.blocks": "model.diffusion_model.output_blocks",
         }
 
-
 class SDXLLoRAFromCivitai(LoRAFromCivitai):
     def __init__(self):
         super().__init__()
@@ -176,7 +460,6 @@ class SDXLLoRAFromCivitai(LoRAFromCivitai):
             "output.blocks": "model.diffusion_model.output_blocks",
            "2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
         }
 
-
 class FluxLoRAFromCivitai(LoRAFromCivitai):
     def __init__(self):
@@ -195,14 +478,29 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
             "txt.mod": "txt_mod",
         }
 
+class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
+    def __init__(self):
+        super().__init__()
+        self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
+        self.lora_prefix = ["diffusion_model.", "transformer."]
+        self.special_keys = {}
 
 class GeneralLoRAFromPeft:
     def __init__(self):
         self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
-
-    def fetch_device_dtype_from_state_dict(self, target_param):
-        """Get device and dtype from a parameter"""
-        return target_param.device, target_param.dtype
+        self.stats = {
+            "tensor_movements_to_gpu": 0,
+            "tensor_movements_to_cpu": 0,
"lora_weights_processed": 0, + "parameter_updates": 0, + } + # Set optimal thread count for CPU operations + self.optimal_threads = optimize_cpu_threading() + self.use_gpu = torch.cuda.is_available() + + # Enable tensor cores for matrix operations if available + if self.use_gpu and hasattr(torch.backends, 'cudnn'): + torch.backends.cudnn.benchmark = True def _get_target_name(self, key): """Extract target parameter name from LoRA key""" @@ -215,105 +512,170 @@ class GeneralLoRAFromPeft: target_name = target_name[len("diffusion_model."):] return target_name + @timing_decorator def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}): - """Original method kept for compatibility with match method""" - device, torch_dtype = None, None - for name, param in target_state_dict.items(): - device, torch_dtype = param.device, param.dtype - break - - if torch_dtype == torch.float8_e4m3fn: - torch_dtype = torch.float32 + if DEBUG: + print(f"Converting state dict with GeneralLoRAFromPeft, memory: {memory_usage()}") + device = "cuda" if self.use_gpu else "cpu" + torch_dtype = torch.float16 if self.use_gpu else torch.float32 state_dict_ = {} - for key in state_dict: - if ".lora_B." not in key: - continue + + # Count applicable keys + applicable_keys = [key for key in state_dict if ".lora_B." in key] + + # Process in batches + BATCH_SIZE = 16 + if DEBUG: + print(f"Processing {len(applicable_keys)} tensors in batches of {BATCH_SIZE}...") + + with tqdm(total=len(applicable_keys), desc="Converting LoRA weights") as pbar: + for i in range(0, len(applicable_keys), BATCH_SIZE): + batch_keys = applicable_keys[i:i+BATCH_SIZE] + for key in batch_keys: + weight_up = state_dict[key].to(device=device, dtype=torch_dtype) + weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) + self.stats["tensor_movements_to_gpu"] += 2 + + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_up, weight_down) + + target_name = self._get_target_name(key) + + if target_state_dict and target_name not in target_state_dict: + pbar.update(1) + continue + + state_dict_[target_name] = lora_weight.cpu() + self.stats["tensor_movements_to_cpu"] += 1 + self.stats["lora_weights_processed"] += 1 + pbar.update(1) - weight_up = state_dict[key].to(device=device, dtype=torch_dtype) - weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) - - if len(weight_up.shape) == 4: - weight_up = weight_up.squeeze(3).squeeze(2) - weight_down = weight_down.squeeze(3).squeeze(2) - lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - else: - lora_weight = alpha * torch.mm(weight_up, weight_down) + # Clear memory after each batch + del weight_up, weight_down, lora_weight + if self.use_gpu: + torch.cuda.empty_cache() - target_name = self._get_target_name(key) - - if target_name not in target_state_dict: - return {} - - state_dict_[target_name] = lora_weight.cpu() - + if DEBUG: + print(f"Conversion complete, resulting in {len(state_dict_)} tensors, memory: {memory_usage()}") + else: + print(f"General LoRA conversion complete: {len(state_dict_)} weights processed") return state_dict_ + @timing_decorator def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""): - """Apply LoRA weights directly to model parameters 
-        """Apply LoRA weights directly to model parameters without loading entire state dict"""
-        # Create parameter name mapping for faster lookup
-        param_dict = {}
-        for name, param in model.named_parameters():
-            param_dict[name] = param
+        """Apply LoRA weights directly to model parameters with batched processing"""
+        print(f"Starting optimized LoRA loading for {model.__class__.__name__}...")
 
-        # Process each LoRA parameter pair
-        modified_count = 0
-        for key in state_dict_lora:
-            if ".lora_B." not in key:
-                continue
-
-            # Get target parameter name and make sure the parameter exists
+        # Create parameter lookup dict
+        start_map = time.time()
+        param_dict = {name: param for name, param in model.named_parameters()}
+        end_map = time.time()
+        if DEBUG:
+            print(f"⏱️ Parameter mapping took {end_map - start_map:.4f} seconds, found {len(param_dict)} parameters")
+        else:
+            print(f"Mapped {len(param_dict)} model parameters")
+
+        # Count applicable LoRA parameters
+        lora_b_keys = [key for key in state_dict_lora if ".lora_B." in key]
+        print(f"Found {len(lora_b_keys)} LoRA parameter pairs to process")
+
+        # Group parameters by shape for better memory access patterns
+        shape_groups = {}
+        for key in lora_b_keys:
             target_name = self._get_target_name(key)
             if target_name not in param_dict:
                 continue
-
-            # Get the target parameter
-            param = param_dict[target_name]
-
-            # Calculate LoRA weight update
-            device, dtype = param.device, param.dtype
-            dtype_for_calc = torch.float32 if dtype == torch.float8_e4m3fn else dtype
-
-            # Process weights and calculate LoRA update
-            weight_b = state_dict_lora[key].to(device=device, dtype=dtype_for_calc)
-            weight_a = state_dict_lora[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=dtype_for_calc)
-
-            if len(weight_b.shape) == 4:
-                weight_b = weight_b.squeeze(3).squeeze(2)
-                weight_a = weight_a.squeeze(3).squeeze(2)
-                lora_weight = alpha * torch.mm(weight_b, weight_a).unsqueeze(2).unsqueeze(3)
-            else:
-                lora_weight = alpha * torch.mm(weight_b, weight_a)
-
-            # Apply update to parameter
-            if dtype == torch.float8_e4m3fn:
-                param_float = param.to(torch.float32)
-                param.data = (param_float + lora_weight).to(dtype)
-                del param_float
-            else:
-                param.data += lora_weight
-
-            # Clean up temporary tensors
-            del weight_a, weight_b, lora_weight
-            modified_count += 1
+            shape = state_dict_lora[key].shape
+            if shape not in shape_groups:
+                shape_groups[shape] = []
+            shape_groups[shape].append((key, target_name))
 
-        print(f"    {modified_count} tensors are updated.")
+        if DEBUG:
+            print(f"Organized into {len(shape_groups)} shape groups for efficient processing")
+
+        # Process each shape group in batches
+        BATCH_SIZE = 32
+        modified_count = 0
+
+        for shape, key_pairs in shape_groups.items():
+            if DEBUG:
+                print(f"Processing {len(key_pairs)} parameters with shape {shape}")
+            for i in range(0, len(key_pairs), BATCH_SIZE):
+                batch = key_pairs[i:i+BATCH_SIZE]
+                for lora_key, target_name in batch:
+                    param = param_dict[target_name]
+                    dtype_for_calc = torch.float32 if param.dtype == torch.float8_e4m3fn else param.dtype
+
+                    # Load weights and compute LoRA update
+                    weight_b = state_dict_lora[lora_key].to(device="cuda" if self.use_gpu else "cpu", dtype=dtype_for_calc)
+                    weight_a = state_dict_lora[lora_key.replace(".lora_B.", ".lora_A.")].to(device="cuda" if self.use_gpu else "cpu", dtype=dtype_for_calc)
+                    self.stats["tensor_movements_to_gpu"] += 2
+
+                    if len(weight_b.shape) == 4:
+                        weight_b = weight_b.squeeze(3).squeeze(2)
+                        weight_a = weight_a.squeeze(3).squeeze(2)
+                        lora_weight = alpha * torch.mm(weight_b, weight_a).unsqueeze(2).unsqueeze(3)
+                    else:
+                        lora_weight = alpha * torch.mm(weight_b, weight_a)
+
+                    # Apply update directly to parameter
+                    if param.dtype == torch.float8_e4m3fn:
+                        param_float = param.to(torch.float32)
+                        param.data = (param_float + lora_weight.to(device=param.device)).to(param.dtype)
+                        del param_float
+                    else:
+                        param.data += lora_weight.to(dtype=param.dtype, device=param.device)
+
+                    del weight_a, weight_b, lora_weight
+                    self.stats["parameter_updates"] += 1
+                    modified_count += 1
+
+            if self.use_gpu:
+                torch.cuda.empty_cache()
+
+        if DEBUG:
+            print("\n==== OPTIMIZED LORA LOADING STATISTICS ====")
+            print(f"Total tensor movements to GPU: {self.stats['tensor_movements_to_gpu']}")
+            print(f"Total LoRA weights processed: {self.stats['lora_weights_processed']}")
+            print(f"Total parameters updated: {self.stats['parameter_updates']}")
+            print(f"Final memory usage: {memory_usage()}")
+            print("==========================================")
+        else:
+            print(f"Optimized LoRA load complete: updated {modified_count} tensors")
+
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        print(f"{modified_count} tensors were updated successfully")
 
+    @timing_decorator
     def match(self, model, state_dict_lora):
-        """Check if LoRA parameters match model parameters without loading full state dict"""
+        """Check if LoRA parameters match model parameters"""
+        if DEBUG:
+            print(f"Checking General LoRA compatibility for {model.__class__.__name__}...")
+
         for model_class in self.supported_model_classes:
             if not isinstance(model, model_class):
                 continue
 
             # Create set of parameter names
-            param_names = set()
-            for name, _ in model.named_parameters():
-                param_names.add(name)
+            start_param = time.time()
+            param_names = set(name for name, _ in model.named_parameters())
+            end_param = time.time()
+            if DEBUG:
+                print(f"⏱️ Parameter name collection took {end_param - start_param:.4f} seconds, found {len(param_names)} names")
 
             # Check if a sample of LoRA keys map to model parameters
             matched_count = 0
             checked_count = 0
+            start_check = time.time()
             for key in state_dict_lora:
                 if ".lora_B." not in key:
                     continue
@@ -324,30 +687,34 @@
                 checked_count += 1
                 if matched_count >= 5:
                     # Found enough matches
-                    return "", ""
+                    break
                 if checked_count >= 50 and matched_count == 0:
                     # Checked enough without matches
                     break
+            end_check = time.time()
+
+            if DEBUG:
+                print(f"⏱️ Match check took {end_check - start_check:.4f} seconds")
+                print(f"Matched {matched_count}/{checked_count} checked parameters")
 
             if matched_count > 0:
+                if DEBUG:
+                    print("✅ Compatible with GeneralLoRAFromPeft")
+                else:
+                    print("LoRA compatibility check: PASS")
                 return "", ""
-
+        if DEBUG:
+            print("❌ Not compatible with GeneralLoRAFromPeft")
         return None
-
-class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
-    def __init__(self):
-        super().__init__()
-        self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
-        self.lora_prefix = ["diffusion_model.", "transformer."]
-        self.special_keys = {}
-
-
 class FluxLoRAConverter:
     def __init__(self):
         pass
 
     @staticmethod
+    @timing_decorator
    def align_to_opensource_format(state_dict, alpha=1.0):
+        if DEBUG:
+            print(f"Converting Flux LoRA to opensource format, input keys: {len(state_dict)}")
         prefix_rename_dict = {
             "single_blocks": "lora_unet_single_blocks",
             "blocks": "lora_unet_double_blocks",
@@ -356,7 +723,6 @@
             "norm.linear": "modulation_lin",
             "to_qkv_mlp": "linear1",
             "proj_out": "linear2",
-
             "norm1_a.linear": "img_mod_lin",
             "norm1_b.linear": "txt_mod_lin",
             "attn.a_to_qkv": "img_attn_qkv",
@@ -373,7 +739,7 @@
             "lora_A.weight": "lora_down.weight",
         }
         state_dict_ = {}
-        for name, param in state_dict.items():
+        for name, param in tqdm(state_dict.items(), desc="Aligning to opensource format"):
             names = name.split(".")
             if names[-2] != "lora_A" and names[-2] != "lora_B":
                 names.pop(-2)
@@ -387,10 +753,17 @@
             state_dict_[rename] = param
             if rename.endswith("lora_up.weight"):
                 state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
+        if DEBUG:
+            print(f"Conversion complete, output keys: {len(state_dict_)}")
+        else:
+            print(f"Flux LoRA conversion complete: {len(state_dict_)} keys")
         return state_dict_
 
     @staticmethod
+    @timing_decorator
     def align_to_diffsynth_format(state_dict):
+        if DEBUG:
+            print(f"Converting to diffsynth format, input keys: {len(state_dict)}")
         rename_dict = {
             "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
             "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
@@ -426,7 +799,7 @@
                 return i, name.replace(f"_{i}_", "_blockid_")
             return None, None
         state_dict_ = {}
-        for name, param in state_dict.items():
+        for name, param in tqdm(state_dict.items(), desc="Aligning to diffsynth format"):
             block_id, source_name = guess_block_id(name)
             if source_name in rename_dict:
                 target_name = rename_dict[source_name]
@@ -434,8 +807,11 @@
                 state_dict_[target_name] = param
             else:
                 state_dict_[name] = param
+        if DEBUG:
+            print(f"Conversion complete, output keys: {len(state_dict_)}")
+        else:
+            print(f"Diffsynth conversion complete: {len(state_dict_)} keys")
         return state_dict_
 
-
 def get_lora_loaders():
     return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
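
Usage sketch (reviewer note, not part of the patch): the public entry point is unchanged — get_lora_loaders() returns loader instances whose match/load pair is probed in order. A minimal driver, assuming `model` is an already-constructed DiffSynth model and `state_dict_lora` is a LoRA state dict loaded elsewhere (both names are illustrative):

    import diffsynth.models.lora as lora

    # Opt into the verbose timing/memory prints added by this patch.
    lora.DEBUG = True

    for loader in lora.get_lora_loaders():
        match = loader.match(model, state_dict_lora)
        if match is not None:
            lora_prefix, model_resource = match
            loader.load(model, state_dict_lora, lora_prefix,
                        alpha=1.0, model_resource=model_resource)
            break

Because load() now mutates param.data in place instead of round-tripping through model.load_state_dict(), the model is usable immediately after the loop with no further state-dict reload.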