mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
bug fix
This commit is contained in:
@@ -46,7 +46,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
},
|
||||
"diffsynth.models.wan_video_dit.WanModel": {
|
||||
"diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
|
||||
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
|
||||
@@ -198,6 +198,66 @@ class AutoWrappedModule(AutoTorchModule):
|
||||
return getattr(self.module, name)
|
||||
|
||||
|
||||
class AutoWrappedNonRecurseModule(AutoWrappedModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
offload_dtype: torch.dtype = None,
|
||||
offload_device: Union[str, torch.device] = None,
|
||||
onload_dtype: torch.dtype = None,
|
||||
onload_device: Union[str, torch.device] = None,
|
||||
preparing_dtype: torch.dtype = None,
|
||||
preparing_device: Union[str, torch.device] = None,
|
||||
computation_dtype: torch.dtype = None,
|
||||
computation_device: Union[str, torch.device] = None,
|
||||
vram_limit: float = None,
|
||||
name: str = "",
|
||||
disk_map: DiskMap = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
module,
|
||||
offload_dtype,
|
||||
offload_device,
|
||||
onload_dtype,
|
||||
onload_device,
|
||||
preparing_dtype,
|
||||
preparing_device,
|
||||
computation_dtype,
|
||||
computation_device,
|
||||
vram_limit,
|
||||
name,
|
||||
disk_map,
|
||||
**kwargs
|
||||
)
|
||||
if self.disk_offload:
|
||||
self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
|
||||
|
||||
def load_from_disk(self, torch_dtype, device, copy_module=False):
|
||||
if copy_module:
|
||||
module = copy.deepcopy(self.module)
|
||||
else:
|
||||
module = self.module
|
||||
state_dict = {}
|
||||
for name in self.required_params:
|
||||
param = self.disk_map[self.param_name(name)]
|
||||
param = param.to(dtype=torch_dtype, device=device)
|
||||
state_dict[name] = param
|
||||
module.load_state_dict(state_dict, assign=True, strict=False)
|
||||
return module
|
||||
|
||||
def offload_to_disk(self, model: torch.nn.Module):
|
||||
for name in self.required_params:
|
||||
getattr(self, name).to("meta")
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self.__dict__ or name == "module":
|
||||
return super().__getattr__(name)
|
||||
else:
|
||||
return getattr(self.module, name)
|
||||
|
||||
|
||||
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -366,11 +426,15 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
|
||||
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
|
||||
if isinstance(model, AutoWrappedNonRecurseModule):
|
||||
model = model.module
|
||||
for name, module in model.named_children():
|
||||
layer_name = name if name_prefix == "" else name_prefix + "." + name
|
||||
for source_module, target_module in module_map.items():
|
||||
if isinstance(module, source_module):
|
||||
module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
|
||||
if isinstance(module_, AutoWrappedNonRecurseModule):
|
||||
enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
|
||||
setattr(model, name, module_)
|
||||
break
|
||||
else:
|
||||
|
||||
@@ -220,7 +220,7 @@ class BasePipeline(torch.nn.Module):
|
||||
module: torch.nn.Module,
|
||||
lora_config: Union[ModelConfig, str] = None,
|
||||
alpha=1,
|
||||
hotload=False,
|
||||
hotload=None,
|
||||
state_dict=None,
|
||||
):
|
||||
if state_dict is None:
|
||||
@@ -233,12 +233,15 @@ class BasePipeline(torch.nn.Module):
|
||||
lora = state_dict
|
||||
lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = lora_loader.convert_state_dict(lora)
|
||||
if hotload is None:
|
||||
hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
|
||||
if hotload:
|
||||
if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
|
||||
raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
|
||||
updated_num = 0
|
||||
for name, module in module.named_modules():
|
||||
for _, module in module.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
name = module.name
|
||||
lora_a_name = f'{name}.lora_A.weight'
|
||||
lora_b_name = f'{name}.lora_B.weight'
|
||||
if lora_a_name in lora and lora_b_name in lora:
|
||||
|
||||
@@ -15,10 +15,10 @@ from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_dit_s2v import rope_precompute
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
|
||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
@@ -526,13 +526,13 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
super().__init__(
|
||||
input_params=("reference_image", "height", "width", "reference_image"),
|
||||
output_params=("reference_latents", "clip_feature"),
|
||||
onload_model_names=("vae",)
|
||||
onload_model_names=("vae", "image_encoder")
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
|
||||
if reference_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(["vae"])
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
reference_image = reference_image.resize((width, height))
|
||||
reference_latents = pipe.preprocess_video([reference_image])
|
||||
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
|
||||
|
||||
Reference in New Issue
Block a user