This commit is contained in:
Artiprocher
2025-11-15 16:47:13 +08:00
parent e3356556ee
commit ea0a5c5908
72 changed files with 481 additions and 472 deletions

View File

@@ -46,7 +46,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
},
"diffsynth.models.wan_video_dit.WanModel": {
"diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",

View File

@@ -198,6 +198,66 @@ class AutoWrappedModule(AutoTorchModule):
return getattr(self.module, name)
class AutoWrappedNonRecurseModule(AutoWrappedModule):
    """AutoWrappedModule variant that manages only the parameters owned
    directly by the wrapped module (``recurse=False``).

    Child submodules are expected to be wrapped separately by the VRAM
    management setup, so disk load/offload here touches exclusively the
    module's own (non-recursive) parameters.
    """

    def __init__(
        self,
        module: torch.nn.Module,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        super().__init__(
            module,
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
            name,
            disk_map,
            **kwargs
        )
        if self.disk_offload:
            # Record only the parameters this module owns directly; children
            # are handled by their own wrappers.
            self.required_params = [
                param_name
                for param_name, _ in self.module.named_parameters(recurse=False)
            ]

    def load_from_disk(self, torch_dtype, device, copy_module=False):
        """Materialize the module's direct parameters from the disk map.

        Each required parameter is fetched via ``self.disk_map``, cast and
        moved to the requested dtype/device, then assigned into either the
        wrapped module or a deep copy of it (when ``copy_module`` is True).
        Returns the module that received the parameters.
        """
        target = copy.deepcopy(self.module) if copy_module else self.module
        state_dict = {
            param_name: self.disk_map[self.param_name(param_name)].to(dtype=torch_dtype, device=device)
            for param_name in self.required_params
        }
        # strict=False: only this module's non-recursive parameters are
        # supplied; assign=True swaps the tensors in rather than copying.
        target.load_state_dict(state_dict, assign=True, strict=False)
        return target

    def offload_to_disk(self, model: torch.nn.Module):
        """Release the module's direct parameters after computation.

        NOTE(review): ``Tensor.to("meta")`` returns a new tensor instead of
        mutating in place, and the result is discarded here — this looks like
        a no-op; confirm against AutoWrappedModule's offload contract.
        """
        for param_name in self.required_params:
            getattr(self, param_name).to("meta")

    def __getattr__(self, name):
        # Delegate unknown attribute lookups to the wrapped module, while
        # attributes stored on the wrapper itself (and "module") resolve
        # through the normal torch.nn.Module path.
        if name == "module" or name in self.__dict__:
            return super().__getattr__(name)
        return getattr(self.module, name)
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def __init__(
self,
@@ -366,11 +426,15 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
if isinstance(model, AutoWrappedNonRecurseModule):
model = model.module
for name, module in model.named_children():
layer_name = name if name_prefix == "" else name_prefix + "." + name
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
if isinstance(module_, AutoWrappedNonRecurseModule):
enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
setattr(model, name, module_)
break
else:

View File

@@ -220,7 +220,7 @@ class BasePipeline(torch.nn.Module):
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=False,
hotload=None,
state_dict=None,
):
if state_dict is None:
@@ -233,12 +233,15 @@ class BasePipeline(torch.nn.Module):
lora = state_dict
lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
lora = lora_loader.convert_state_dict(lora)
if hotload is None:
hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
if hotload:
if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
updated_num = 0
for name, module in module.named_modules():
for _, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
name = module.name
lora_a_name = f'{name}.lora_A.weight'
lora_b_name = f'{name}.lora_B.weight'
if lora_a_name in lora and lora_b_name in lora:

View File

@@ -15,10 +15,10 @@ from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d
from ..models.wan_video_dit_s2v import rope_precompute
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
@@ -526,13 +526,13 @@ class WanVideoUnit_FunReference(PipelineUnit):
super().__init__(
input_params=("reference_image", "height", "width", "reference_image"),
output_params=("reference_latents", "clip_feature"),
onload_model_names=("vae",)
onload_model_names=("vae", "image_encoder")
)
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
if reference_image is None:
return {}
pipe.load_models_to_device(["vae"])
pipe.load_models_to_device(self.onload_model_names)
reference_image = reference_image.resize((width, height))
reference_latents = pipe.preprocess_video([reference_image])
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)