training framework

This commit is contained in:
Artiprocher
2025-05-12 17:48:28 +08:00
parent dbef6122e9
commit 675eefa07e
20 changed files with 939 additions and 174 deletions

View File

@@ -1,34 +1,26 @@
import torch, warnings, glob
import torch, warnings, glob, os
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
import types
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
import torch, os
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models import ModelManager
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..schedulers.flow_match import FlowMatchScheduler
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
@@ -50,6 +42,16 @@ class BasePipeline(torch.nn.Module):
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.vram_management_enabled = False
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def check_resize_height_width(self, height, width, num_frames=None):
@@ -135,8 +137,20 @@ class BasePipeline(torch.nn.Module):
def enable_cpu_offload(self):
warnings.warn("enable_cpu_offload is deprecated. This feature is automatically enabled if offload_device != device")
warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.")
def get_free_vram(self):
total_memory = torch.cuda.get_device_properties(self.device).total_memory
allocated_memory = torch.cuda.device_memory_used(self.device)
return (total_memory - allocated_memory) / (1024 ** 3)
def freeze_except(self, model_names):
for name, model in self.named_children():
if name not in model_names:
model.eval()
model.requires_grad_(False)
@dataclass
@@ -146,17 +160,19 @@ class ModelConfig:
origin_file_pattern: Union[str, list[str]] = None
download_resource: str = "ModelScope"
offload_device: Optional[Union[str, torch.device]] = None
quantization_dtype: Optional[torch.dtype] = None
offload_dtype: Optional[torch.dtype] = None
def download_if_necessary(self, local_model_path="./models", skip_download=False):
if self.path is None:
if self.model_id is None or self.origin_file_pattern is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(
self.model_id,
local_dir=os.path.join(local_model_path, self.model_id),
allow_file_pattern=self.origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
@@ -195,10 +211,36 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
]
self.model_fn = model_fn_wan_video
def train(self):
super().train()
self.scheduler.set_timesteps(1000, training=True)
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
noise_pred = self.model_fn(**inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.scheduler.training_weight(timestep)
return loss
def enable_vram_management(self, num_persistent_param_in_dit=None):
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_free_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
dtype = next(iter(self.text_encoder.parameters())).dtype
enable_vram_management(
@@ -217,9 +259,11 @@ class WanVideoPipeline(BasePipeline):
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map = {
@@ -233,7 +277,7 @@ class WanVideoPipeline(BasePipeline):
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
@@ -246,6 +290,7 @@ class WanVideoPipeline(BasePipeline):
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
dtype = next(iter(self.vae.parameters())).dtype
@@ -304,6 +349,7 @@ class WanVideoPipeline(BasePipeline):
),
)
if self.vace is not None:
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.vace,
module_map = {
@@ -316,10 +362,11 @@ class WanVideoPipeline(BasePipeline):
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
@@ -330,8 +377,23 @@ class WanVideoPipeline(BasePipeline):
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
local_model_path: str = "./models",
skip_download: bool = False
skip_download: bool = False,
redirect_common_files: bool = True,
):
# Redirect model path
if redirect_common_files:
redirect_dict = {
"models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
}
for model_config in model_configs:
if model_config.origin_file_pattern is None or model_config.model_id is None:
continue
if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
@@ -339,7 +401,7 @@ class WanVideoPipeline(BasePipeline):
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.quantization_dtype or torch_dtype
torch_dtype=model_config.offload_dtype or torch_dtype
)
# Initialize pipeline
@@ -356,63 +418,54 @@ class WanVideoPipeline(BasePipeline):
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
return pipe
def denoising_model(self):
return self.dit
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
@torch.no_grad()
def __call__(
self,
# Prompt
prompt,
negative_prompt="",
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image=None,
input_image: Optional[Image.Image] = None,
# First-last-frame-to-video
end_image=None,
end_image: Optional[Image.Image] = None,
# Video-to-video
input_video=None,
denoising_strength=1.0,
input_video: Optional[list[Image.Image]] = None,
denoising_strength: Optional[float] = 1.0,
# ControlNet
control_video=None,
reference_image=None,
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
# VACE
vace_video=None,
vace_video_mask=None,
vace_reference_image=None,
vace_scale=1.0,
vace_video: Optional[list[Image.Image]] = None,
vace_video_mask: Optional[Image.Image] = None,
vace_reference_image: Optional[Image.Image] = None,
vace_scale: Optional[float] = 1.0,
# Randomness
seed=None,
rand_device="cpu",
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
# Shape
height=480,
width=832,
height: Optional[int] = 480,
width: Optional[int] = 832,
num_frames=81,
# Classifier-free guidance
cfg_scale=5.0,
cfg_merge=False,
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
# Scheduler
num_inference_steps=50,
sigma_shift=5.0,
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
# Speed control
motion_bucket_id=None,
motion_bucket_id: Optional[int] = None,
# VAE tiling
tiled=True,
tile_size=(30, 52),
tile_stride=(15, 26),
tiled: Optional[bool] = True,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Sliding window
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
# Teacache
tea_cache_l1_thresh=None,
tea_cache_model_id="",
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# progress_bar
progress_bar_cmd=tqdm,
):
@@ -452,12 +505,12 @@ class WanVideoPipeline(BasePipeline):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = model_fn_wan_video(**models, **inputs_shared, **inputs_posi, timestep=timestep)
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
if cfg_scale != 1.0:
if cfg_merge:
noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
else:
noise_pred_nega = model_fn_wan_video(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -467,7 +520,7 @@ class WanVideoPipeline(BasePipeline):
# VACE (TODO: remove it)
if vace_reference_image is not None:
latents = latents[:, :, 1:]
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
# Decode
self.load_models_to_device(['vae'])
@@ -558,18 +611,21 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride"),
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride):
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(["vae"])
input_video = pipe.preprocess_video(input_video)
latents = pipe.encode_video(input_video, tiled, tile_size, tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
latents = pipe.scheduler.add_noise(latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents}
@@ -639,7 +695,7 @@ class WanVideoUnit_FunControl(PipelineUnit):
return {}
pipe.load_models_to_device(self.onload_model_names)
control_video = pipe.preprocess_video(control_video)
control_latents = pipe.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
if clip_feature is None or y is None:
clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
@@ -678,7 +734,7 @@ class WanVideoUnit_SpeedControl(PipelineUnit):
def process(self, pipe: WanVideoPipeline, motion_bucket_id):
if motion_bucket_id is None:
return {}
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_bucket_id": motion_bucket_id}
@@ -703,18 +759,16 @@ class WanVideoUnit_VACE(PipelineUnit):
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
else:
vace_video = pipe.preprocess_video(vace_video)
vace_video = torch.stack(vace_video, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
if vace_mask is None:
vace_mask = torch.ones_like(vace_video)
else:
vace_mask = pipe.preprocess_video(vace_mask)
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
inactive = pipe.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_video_latents = torch.concat((inactive, reactive), dim=1)
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
@@ -724,8 +778,7 @@ class WanVideoUnit_VACE(PipelineUnit):
pass
else:
vace_reference_image = pipe.preprocess_video([vace_reference_image])
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = pipe.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
@@ -894,6 +947,7 @@ def model_fn_wan_video(
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
cfg_merge: bool = False,
use_gradient_checkpointing: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -978,8 +1032,20 @@ def model_fn_wan_video(
if tea_cache_update:
x = tea_cache.update(x)
else:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(dit.blocks):
x = block(x, context, t_mod, freqs)
if use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
if vace_context is not None and block_id in vace.vace_layers_mapping:
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
if tea_cache is not None:

View File

@@ -35,6 +35,9 @@ class FlowMatchScheduler():
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
self.training = True
else:
self.training = False
def step(self, model_output, timestep, sample, to_final=False, **kwargs):

190
diffsynth/trainers/utils.py Normal file
View File

@@ -0,0 +1,190 @@
import imageio, os, torch, warnings, torchvision
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
from tqdm import tqdm
from accelerate import Accelerator
class VideoDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path, metadata_path,
frame_interval=1, num_frames=81,
dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
data_file_keys=("video",),
image_file_extension=("jpg", "jpeg", "png", "webp"),
video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
repeat=1,
):
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
self.base_path = base_path
self.frame_interval = frame_interval
self.num_frames = num_frames
self.dynamic_resolution = dynamic_resolution
self.max_pixels = max_pixels
self.height = height
self.width = width
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.data_file_keys = data_file_keys
self.image_file_extension = image_file_extension
self.video_file_extension = video_file_extension
self.repeat = repeat
if height is not None and width is not None and dynamic_resolution == True:
print("Height and width are fixed. Setting `dynamic_resolution` to False.")
self.dynamic_resolution = False
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.dynamic_resolution:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames):
reader = imageio.get_reader(file_path)
if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
frames.append(frame)
reader.close()
return frames
def load_image(self, file_path):
image = Image.open(file_path).convert("RGB")
image = self.crop_and_resize(image, *self.get_height_width(image))
return image
def load_video(self, file_path):
frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)
return frames
def is_image(self, file_path):
file_ext_name = file_path.split(".")[-1]
return file_ext_name.lower() in self.image_file_extension
def is_video(self, file_path):
file_ext_name = file_path.split(".")[-1]
return file_ext_name.lower() in self.video_file_extension
def load_data(self, file_path):
if self.is_image(file_path):
return self.load_image(file_path)
elif self.is_video(file_path):
return self.load_video(file_path)
else:
return None
def __getitem__(self, data_id):
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
path = os.path.join(self.base_path, data[key])
data[key] = self.load_data(path)
if data[key] is None:
warnings.warn(f"cannot load file {data[key]}.")
return None
return data
def __len__(self):
return len(self.data) * self.repeat
class DiffusionTrainingModule(torch.nn.Module):
def __init__(self):
super().__init__()
def to(self, *args, **kwargs):
for name, model in self.named_children():
model.to(*args, **kwargs)
return self
def trainable_modules(self):
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
return trainable_modules
def trainable_param_names(self):
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
return trainable_param_names
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
if lora_alpha is None:
lora_alpha = lora_rank
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
model = inject_adapter_in_model(lora_config, model)
return model
def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate, num_epochs, output_path, remove_prefix=None):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
accelerator = Accelerator(gradient_accumulation_steps=1)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch in range(num_epochs):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(data)
accelerator.backward(loss)
optimizer.step()
scheduler.step()
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
trainable_param_names = model.trainable_param_names()
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
if remove_prefix is not None:
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith(remove_prefix):
name = name[len(remove_prefix):]
state_dict_[name] = param
path = os.path.join(output_path, f"epoch-{epoch}")
accelerator.save(state_dict_, path, safe_serialization=True)

View File

@@ -8,8 +8,32 @@ def cast_to(weight, dtype, device):
return r
class AutoWrappedModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
class AutoTorchModule(torch.nn.Module):
def __init__(self):
super().__init__()
def check_free_vram(self):
used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3)
return used_memory < self.vram_limit
def offload(self):
if self.state != 0:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state != 1:
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def keep(self):
if self.state != 2:
self.to(dtype=self.computation_dtype, device=self.computation_device)
self.state = 2
class AutoWrappedModule(AutoTorchModule):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
super().__init__()
self.module = module.to(dtype=offload_dtype, device=offload_device)
self.offload_dtype = offload_dtype
@@ -18,28 +42,25 @@ class AutoWrappedModule(torch.nn.Module):
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
if self.state == 2:
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
module = self.module
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
return module(*args, **kwargs)
class WanAutoCastLayerNorm(torch.nn.LayerNorm):
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
@@ -50,31 +71,28 @@ class WanAutoCastLayerNorm(torch.nn.LayerNorm):
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
if self.state == 2:
weight, bias = self.weight, self.bias
else:
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
weight, bias = self.weight, self.bias
else:
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
with torch.amp.autocast(device_type=x.device.type):
x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x)
return x
class AutoWrappedLinear(torch.nn.Linear):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
@@ -85,28 +103,25 @@ class AutoWrappedLinear(torch.nn.Linear):
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0
def offload(self):
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def forward(self, x, *args, **kwargs):
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
if self.state == 2:
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
return torch.nn.functional.linear(x, weight, bias)
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None):
for name, module in model.named_children():
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
@@ -115,16 +130,16 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict,
module_config_ = overflow_module_config
else:
module_config_ = module_config
module_ = target_module(module, **module_config_)
module_ = target_module(module, **module_config_, vram_limit=vram_limit)
setattr(model, name, module_)
total_num_param += num_param
break
else:
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit)
return total_num_param
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit)
model.vram_management_enabled = True