training framework

Artiprocher
2025-05-12 17:48:28 +08:00
parent dbef6122e9
commit 675eefa07e
20 changed files with 939 additions and 174 deletions


@@ -1,34 +1,26 @@
import torch, warnings, glob
import torch, warnings, glob, os
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
import types
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
import torch, os
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..schedulers.flow_match import FlowMatchScheduler
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
@@ -50,6 +42,16 @@ class BasePipeline(torch.nn.Module):
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.vram_management_enabled = False
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
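    # Usage sketch (assuming a fully constructed pipeline): `pipe.to("cuda", torch.bfloat16)`
    # moves parameters as usual while also keeping the bookkeeping attributes
    # `self.device` / `self.torch_dtype` in sync for downstream unit calls.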
def check_resize_height_width(self, height, width, num_frames=None):
@@ -135,8 +137,20 @@ class BasePipeline(torch.nn.Module):
def enable_cpu_offload(self):
warnings.warn("enable_cpu_offload is deprecated. This feature is automatically enabled if offload_device != device")
warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.")
def get_free_vram(self):
total_memory = torch.cuda.get_device_properties(self.device).total_memory
allocated_memory = torch.cuda.device_memory_used(self.device)
return (total_memory - allocated_memory) / (1024 ** 3)
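    # Reports free VRAM in GiB. `torch.cuda.device_memory_used` only exists in
    # recent PyTorch releases (assumption); on older versions,
    # `torch.cuda.memory_allocated(self.device)` would be the closest substitute.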
def freeze_except(self, model_names):
for name, model in self.named_children():
if name not in model_names:
model.eval()
model.requires_grad_(False)
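    # Example: `pipe.freeze_except(["dit"])` keeps only the DiT trainable; every
    # other child module is switched to eval mode with gradients disabled.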
@dataclass
@@ -146,17 +160,19 @@ class ModelConfig:
origin_file_pattern: Union[str, list[str]] = None
download_resource: str = "ModelScope"
offload_device: Optional[Union[str, torch.device]] = None
quantization_dtype: Optional[torch.dtype] = None
offload_dtype: Optional[torch.dtype] = None
def download_if_necessary(self, local_model_path="./models", skip_download=False):
if self.path is None:
if self.model_id is None or self.origin_file_pattern is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(
self.model_id,
local_dir=os.path.join(local_model_path, self.model_id),
allow_file_pattern=self.origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
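    # Hypothetical example of the resolution flow: with
    # ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
    # files already present locally are passed as `ignore_file_pattern` so resumed
    # runs skip them, and `self.path` is then resolved by globbing the local dir.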
@@ -195,10 +211,36 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
]
self.model_fn = model_fn_wan_video
def train(self):
super().train()
self.scheduler.set_timesteps(1000, training=True)
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
noise_pred = self.model_fn(**inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.scheduler.training_weight(timestep)
return loss
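    # Minimal training-step sketch (assuming `inputs` was prepared by the pipeline
    # units and contains `input_latents` and `noise`, and that `optimizer` targets
    # the modules left trainable by `freeze_except`):
    #     pipe.train()
    #     loss = pipe.training_loss(**inputs)
    #     loss.backward()
    #     optimizer.step()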
def enable_vram_management(self, num_persistent_param_in_dit=None):
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_free_vram()
vram_limit = vram_limit - vram_buffer
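        # `vram_limit` is interpreted in GiB (matching `get_free_vram`) and only
        # applies when `num_persistent_param_in_dit` is unset; `vram_buffer`
        # reserves headroom for activations on top of the measured free memory.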
if self.text_encoder is not None:
dtype = next(iter(self.text_encoder.parameters())).dtype
enable_vram_management(
@@ -217,9 +259,11 @@ class WanVideoPipeline(BasePipeline):
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map = {
@@ -233,7 +277,7 @@ class WanVideoPipeline(BasePipeline):
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
@@ -246,6 +290,7 @@ class WanVideoPipeline(BasePipeline):
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
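        # With a `vram_limit`, the DiT onloads to CPU and, presumably, wrapped
        # layers are promoted to the GPU on demand until the limit is reached;
        # without one, the persistent-parameter scheme applies unchanged.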
if self.vae is not None:
dtype = next(iter(self.vae.parameters())).dtype
@@ -304,6 +349,7 @@ class WanVideoPipeline(BasePipeline):
),
)
if self.vace is not None:
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.vace,
module_map = {
@@ -316,10 +362,11 @@ class WanVideoPipeline(BasePipeline):
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
@@ -330,8 +377,23 @@ class WanVideoPipeline(BasePipeline):
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
local_model_path: str = "./models",
skip_download: bool = False
skip_download: bool = False,
redirect_common_files: bool = True,
):
# Redirect model path
if redirect_common_files:
redirect_dict = {
"models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
}
for model_config in model_configs:
if model_config.origin_file_pattern is None or model_config.model_id is None:
continue
if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
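            # Net effect: checkpoints that ship identical copies of the T5
            # encoder, VAE, or CLIP weights share one download under the
            # redirected repo id.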
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
@@ -339,7 +401,7 @@ class WanVideoPipeline(BasePipeline):
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.quantization_dtype or torch_dtype
torch_dtype=model_config.offload_dtype or torch_dtype
)
# Initialize pipeline
@@ -356,63 +418,54 @@ class WanVideoPipeline(BasePipeline):
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
return pipe
def denoising_model(self):
return self.dit
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
@torch.no_grad()
def __call__(
self,
# Prompt
prompt,
negative_prompt="",
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image=None,
input_image: Optional[Image.Image] = None,
# First-last-frame-to-video
end_image=None,
end_image: Optional[Image.Image] = None,
# Video-to-video
input_video=None,
denoising_strength=1.0,
input_video: Optional[list[Image.Image]] = None,
denoising_strength: Optional[float] = 1.0,
# ControlNet
control_video=None,
reference_image=None,
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
# VACE
vace_video=None,
vace_video_mask=None,
vace_reference_image=None,
vace_scale=1.0,
vace_video: Optional[list[Image.Image]] = None,
vace_video_mask: Optional[Image.Image] = None,
vace_reference_image: Optional[Image.Image] = None,
vace_scale: Optional[float] = 1.0,
# Randomness
seed=None,
rand_device="cpu",
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height=480,
width=832,
height: Optional[int] = 480,
width: Optional[int] = 832,
num_frames=81,
# Classifier-free guidance
cfg_scale=5.0,
cfg_merge=False,
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
# Scheduler
num_inference_steps=50,
sigma_shift=5.0,
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
# Speed control
motion_bucket_id=None,
motion_bucket_id: Optional[int] = None,
# VAE tiling
tiled=True,
tile_size=(30, 52),
tile_stride=(15, 26),
tiled: Optional[bool] = True,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Sliding window
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
# Teacache
tea_cache_l1_thresh=None,
tea_cache_model_id="",
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
):
@@ -452,12 +505,12 @@ class WanVideoPipeline(BasePipeline):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = model_fn_wan_video(**models, **inputs_shared, **inputs_posi, timestep=timestep)
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
if cfg_scale != 1.0:
if cfg_merge:
noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
else:
noise_pred_nega = model_fn_wan_video(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -467,7 +520,7 @@ class WanVideoPipeline(BasePipeline):
# VACE (TODO: remove it)
if vace_reference_image is not None:
latents = latents[:, :, 1:]
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
# Decode
self.load_models_to_device(['vae'])
@@ -558,18 +611,21 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride"),
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride):
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(["vae"])
input_video = pipe.preprocess_video(input_video)
latents = pipe.encode_video(input_video, tiled, tile_size, tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
latents = pipe.scheduler.add_noise(latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}
@@ -639,7 +695,7 @@ class WanVideoUnit_FunControl(PipelineUnit):
return {}
pipe.load_models_to_device(self.onload_model_names)
control_video = pipe.preprocess_video(control_video)
control_latents = pipe.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
if clip_feature is None or y is None:
clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
@@ -678,7 +734,7 @@ class WanVideoUnit_SpeedControl(PipelineUnit):
def process(self, pipe: WanVideoPipeline, motion_bucket_id):
if motion_bucket_id is None:
return {}
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_bucket_id": motion_bucket_id}
@@ -703,18 +759,16 @@ class WanVideoUnit_VACE(PipelineUnit):
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
else:
vace_video = pipe.preprocess_video(vace_video)
vace_video = torch.stack(vace_video, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
if vace_mask is None:
vace_mask = torch.ones_like(vace_video)
else:
vace_mask = pipe.preprocess_video(vace_mask)
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
inactive = pipe.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_video_latents = torch.concat((inactive, reactive), dim=1)
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
@@ -724,8 +778,7 @@ class WanVideoUnit_VACE(PipelineUnit):
pass
else:
vace_reference_image = pipe.preprocess_video([vace_reference_image])
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = pipe.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
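        # The reference frame is prepended along the temporal axis with zeroed
        # mask latents, presumably marking it as fixed conditioning rather than
        # content to be regenerated.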
@@ -894,6 +947,7 @@ def model_fn_wan_video(
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
cfg_merge: bool = False,
use_gradient_checkpointing: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -978,8 +1032,20 @@ def model_fn_wan_video(
if tea_cache_update:
x = tea_cache.update(x)
else:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
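            # The closure hands torch.utils.checkpoint a plain positional-arg
            # callable; with `use_reentrant=False` each block is recomputed during
            # backward, trading extra compute for lower activation memory.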
for block_id, block in enumerate(dit.blocks):
x = block(x, context, t_mod, freqs)
if use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
if vace_context is not None and block_id in vace.vace_layers_mapping:
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
if tea_cache is not None: