new wan trainer

This commit is contained in:
Artiprocher
2025-06-06 14:58:41 +08:00
parent 8f10a9c353
commit 62f6ca2b8a
87 changed files with 1779 additions and 1543 deletions

View File

@@ -11,7 +11,7 @@ from PIL import Image
from tqdm import tqdm
from typing import Optional
from ..models import ModelManager
from ..models import ModelManager, load_state_dict
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
@@ -21,6 +21,7 @@ from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..schedulers.flow_match import FlowMatchScheduler
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
from ..lora import GeneralLoRALoader
@@ -137,7 +138,8 @@ class BasePipeline(torch.nn.Module):
def enable_cpu_offload(self):
warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.")
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
self.vram_management_enabled = True
def get_free_vram(self):
@@ -183,7 +185,6 @@ class ModelConfig:
self.path = self.path[0]
class WanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
@@ -216,6 +217,12 @@ class WanVideoPipeline(BasePipeline):
]
self.model_fn = model_fn_wan_video
def load_lora(self, module, path, alpha=1):
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
@@ -946,6 +953,7 @@ def model_fn_wan_video(
sliding_window_stride: Optional[int] = None,
cfg_merge: bool = False,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1036,7 +1044,14 @@ def model_fn_wan_video(
return custom_forward
for block_id, block in enumerate(dit.blocks):
if use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,