mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-20 23:58:12 +00:00)
new wan trainer
@@ -11,7 +11,7 @@ from PIL import Image
 from tqdm import tqdm
 from typing import Optional
 
-from ..models import ModelManager
+from ..models import ModelManager, load_state_dict
 from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
 from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
 from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
@@ -21,6 +21,7 @@ from ..models.wan_video_motion_controller import WanMotionControllerModel
 from ..schedulers.flow_match import FlowMatchScheduler
 from ..prompters import WanPrompter
 from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
+from ..lora import GeneralLoRALoader
 
 
 
@@ -137,7 +138,8 @@ class BasePipeline(torch.nn.Module):
 
 
     def enable_cpu_offload(self):
-        warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.")
+        warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
+        self.vram_management_enabled = True
 
 
     def get_free_vram(self):
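Note: this hunk softens the deprecation message and makes the old entry point flip the VRAM-management flag instead of doing its own offloading. A minimal sketch of the migration the warning asks for, assuming a constructed pipeline instance and that `enable_vram_management()` accepts a no-argument call (not shown in this hunk):

# Given a constructed WanVideoPipeline instance `pipe` (construction omitted here):
pipe.enable_cpu_offload()       # old entry point: now only warns and sets vram_management_enabled
pipe.enable_vram_management()   # preferred call per the warning text (no-argument form assumed)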
@@ -183,7 +185,6 @@ class ModelConfig:
             self.path = self.path[0]
 
 
-
 class WanVideoPipeline(BasePipeline):
 
     def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
@@ -216,6 +217,12 @@ class WanVideoPipeline(BasePipeline):
         ]
         self.model_fn = model_fn_wan_video
 
 
+    def load_lora(self, module, path, alpha=1):
+        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
+        loader.load(module, lora, alpha=alpha)
+
+
     def training_loss(self, **inputs):
         timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
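Note: the new `load_lora` helper wires `load_state_dict` and `GeneralLoRALoader` together so a LoRA checkpoint can be merged into any submodule of the pipeline. A usage sketch, assuming the pipeline keeps its DiT in a `dit` attribute and with a placeholder checkpoint path:

# Given a constructed WanVideoPipeline instance `pipe`:
pipe.load_lora(pipe.dit, "path/to/lora.safetensors", alpha=1.0)   # `pipe.dit` and the path are illustrative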
@@ -946,6 +953,7 @@ def model_fn_wan_video(
     sliding_window_stride: Optional[int] = None,
     cfg_merge: bool = False,
     use_gradient_checkpointing: bool = False,
+    use_gradient_checkpointing_offload: bool = False,
     **kwargs,
 ):
     if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1036,7 +1044,14 @@ def model_fn_wan_video(
         return custom_forward
 
     for block_id, block in enumerate(dit.blocks):
-        if use_gradient_checkpointing:
+        if use_gradient_checkpointing_offload:
+            with torch.autograd.graph.save_on_cpu():
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    x, context, t_mod, freqs,
+                    use_reentrant=False,
+                )
+        elif use_gradient_checkpointing:
             x = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(block),
                 x, context, t_mod, freqs,
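Note: the rewritten block loop adds a CPU-offload variant of activation checkpointing: when `use_gradient_checkpointing_offload` is set, tensors saved for backward are packed to host memory via `torch.autograd.graph.save_on_cpu()` while each block is still recomputed on backward by `torch.utils.checkpoint.checkpoint`. A self-contained sketch of the same pattern on a toy stack of layers (standard PyTorch APIs; the layers and sizes are made up for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"
blocks = nn.ModuleList([nn.Linear(256, 256) for _ in range(4)]).to(device)
x = torch.randn(8, 256, device=device, requires_grad=True)

use_gradient_checkpointing = True
use_gradient_checkpointing_offload = True

for block in blocks:
    if use_gradient_checkpointing_offload:
        # Tensors saved for backward are staged in CPU RAM, trading transfer time for device memory.
        with torch.autograd.graph.save_on_cpu():
            x = checkpoint(block, x, use_reentrant=False)
    elif use_gradient_checkpointing:
        # Recompute activations on backward instead of keeping them resident.
        x = checkpoint(block, x, use_reentrant=False)
    else:
        x = block(x)

x.sum().backward()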