mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
rearrange lora loading modules
@@ -22,8 +22,8 @@ from ..models.flux_value_control import MultiValueEncoder
 from ..models.flux_infiniteyou import InfiniteYouImageProjector
 from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
 from ..models.tiler import FastTileWorker
-from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
-from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher
+from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
+from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser
 from ..models.flux_dit import RMSNorm
 from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
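This first hunk rewires the pipeline's shared primitives: BasePipeline, ModelConfig, and the pipeline-unit classes are now imported from ..utils instead of being re-exported through .wan_video_new, and FluxLoRAFuser joins the LoRA imports for the new fused-LoRA path below. A minimal sketch of how these relative imports resolve, assuming the file sits under diffsynth/pipelines/ (the standard layout of this repo):

# Sketch only: assumes this file lives in diffsynth/pipelines/, so the
# relative imports in the hunk above resolve to these absolute ones.
from diffsynth.utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from diffsynth.lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser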
@@ -125,18 +125,20 @@ class FluxImagePipeline(BasePipeline):
     def load_lora(
         self,
         module: torch.nn.Module,
-        lora_config: Union[ModelConfig, str],
+        lora_config: Union[ModelConfig, str] = None,
         alpha=1,
         hotload=False,
-        local_model_path="./models",
-        skip_download=False
+        state_dict=None,
     ):
-        if isinstance(lora_config, str):
-            lora_config = ModelConfig(path=lora_config)
-        lora_config.download_if_necessary(local_model_path, skip_download=skip_download)
+        if state_dict is None:
+            if isinstance(lora_config, str):
+                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+            else:
+                lora_config.download_if_necessary()
+                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        else:
+            lora = state_dict
         loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
-        lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
         lora = loader.convert_state_dict(lora)
         if hotload:
             for name, module in module.named_modules():
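After this hunk, load_lora accepts either a lora_config (a path string or a ModelConfig) or an already-loaded state_dict, and only triggers a download when a ModelConfig is given. A minimal usage sketch, assuming pipe is a FluxImagePipeline; the LoRA paths and my_state_dict are placeholders, not names from this diff:

# Sketch only: pipe is a FluxImagePipeline; paths and my_state_dict are placeholders.

# 1) Plain path string: read directly with load_state_dict, no download step.
pipe.load_lora(pipe.dit, "models/lora/my_flux_lora.safetensors", alpha=1)

# 2) ModelConfig: download_if_necessary() is now called with no arguments.
pipe.load_lora(pipe.dit, ModelConfig(path="models/lora/my_flux_lora.safetensors"))

# 3) Pre-loaded state dict: bypasses file I/O entirely (this is the path the
#    fused-LoRA feature below relies on).
pipe.load_lora(pipe.dit, state_dict=my_state_dict, hotload=True)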
@@ -150,19 +152,21 @@ class FluxImagePipeline(BasePipeline):
             loader.load(module, lora, alpha=alpha)


     def enable_lora_patcher(self):
         if not (hasattr(self, "vram_management_enabled") and self.vram_management_enabled):
             print("Please enable VRAM management using `enable_vram_management()` before `enable_lora_patcher()`.")
             return
         if self.lora_patcher is None:
             print("Please load lora patcher models before `enable_lora_patcher()`.")
             return
         for name, module in self.dit.named_modules():
             if isinstance(module, AutoWrappedLinear):
                 merger_name = name.replace(".", "___")
                 if merger_name in self.lora_patcher.model_dict:
                     module.lora_merger = self.lora_patcher.model_dict[merger_name]

     def load_loras(
         self,
         module: torch.nn.Module,
         lora_configs: list[Union[ModelConfig, str]],
         alpha=1,
         hotload=False,
+        extra_fused_lora=False,
     ):
         for lora_config in lora_configs:
             self.load_lora(module, lora_config, hotload=hotload, alpha=alpha)
+        if extra_fused_lora:
+            lora_fuser = FluxLoRAFuser(device="cuda", torch_dtype=torch.bfloat16)
+            fused_lora = lora_fuser(lora_configs)
+            self.load_lora(module, state_dict=fused_lora, hotload=hotload, alpha=alpha)


     def clear_lora(self):
         for name, module in self.named_modules():
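The new extra_fused_lora flag loads every LoRA individually and then, in addition, merges them all into a single state dict with FluxLoRAFuser and loads the fused result through the state_dict path of load_lora. A usage sketch with placeholder paths:

# Sketch only: pipe is a FluxImagePipeline; the LoRA paths are placeholders.
pipe.load_loras(
    pipe.dit,
    ["models/lora/style.safetensors", "models/lora/detail.safetensors"],
    alpha=1,
    hotload=True,           # forwarded to load_lora for each entry
    extra_fused_lora=True,  # also fuse both LoRAs and load the fused state dict
)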
@@ -365,16 +369,11 @@ class FluxImagePipeline(BasePipeline):
         torch_dtype: torch.dtype = torch.bfloat16,
         device: Union[str, torch.device] = "cuda",
         model_configs: list[ModelConfig] = [],
         tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
-        local_model_path: str = "./models",
-        skip_download: bool = False,
-        redirect_common_files: bool = True,
         use_usp=False,
     ):
         # Download and load models
         model_manager = ModelManager()
         for model_config in model_configs:
-            model_config.download_if_necessary(local_model_path, skip_download=skip_download)
+            model_config.download_if_necessary()
             model_manager.load_model(
                 model_config.path,
                 device=model_config.offload_device or device,
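The final hunk simplifies from_pretrained: local_model_path, skip_download, and redirect_common_files are dropped, and download_if_necessary() is called without arguments, so download behavior now lives entirely inside ModelConfig. A call sketch; the model id and file pattern below are illustrative placeholders, not values taken from this diff:

# Sketch only: the model id / file pattern are illustrative placeholders.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
    ],
)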