mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
new wan trainer
This commit is contained in:
45
diffsynth/lora/__init__.py
Normal file
45
diffsynth/lora/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class GeneralLoRALoader:
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
keys.pop(-1)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
updated_num = 0
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name, module in model.named_modules():
|
||||
if name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
state_dict = module.state_dict()
|
||||
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
||||
module.load_state_dict(state_dict)
|
||||
updated_num += 1
|
||||
print(f"{updated_num} tensors are updated by LoRA.")
|
||||
@@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
with safe_open(file_path, framework="pt", device=device) as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
@@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
|
||||
@@ -11,7 +11,7 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from ..models import ModelManager
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
||||
@@ -21,6 +21,7 @@ from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from ..prompters import WanPrompter
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
|
||||
from ..lora import GeneralLoRALoader
|
||||
|
||||
|
||||
|
||||
@@ -137,7 +138,8 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.")
|
||||
warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
|
||||
self.vram_management_enabled = True
|
||||
|
||||
|
||||
def get_free_vram(self):
|
||||
@@ -183,7 +185,6 @@ class ModelConfig:
|
||||
self.path = self.path[0]
|
||||
|
||||
|
||||
|
||||
class WanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
|
||||
@@ -216,6 +217,12 @@ class WanVideoPipeline(BasePipeline):
|
||||
]
|
||||
self.model_fn = model_fn_wan_video
|
||||
|
||||
|
||||
def load_lora(self, module, path, alpha=1):
|
||||
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
loader.load(module, lora, alpha=alpha)
|
||||
|
||||
|
||||
def training_loss(self, **inputs):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
@@ -946,6 +953,7 @@ def model_fn_wan_video(
|
||||
sliding_window_stride: Optional[int] = None,
|
||||
cfg_merge: bool = False,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if sliding_window_size is not None and sliding_window_stride is not None:
|
||||
@@ -1036,7 +1044,14 @@ def model_fn_wan_video(
|
||||
return custom_forward
|
||||
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
if use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
|
||||
@@ -25,6 +25,7 @@ class VideoDataset(torch.utils.data.Dataset):
|
||||
metadata_path = args.dataset_metadata_path
|
||||
height = args.height
|
||||
width = args.width
|
||||
num_frames = args.num_frames
|
||||
data_file_keys = args.data_file_keys.split(",")
|
||||
repeat = args.dataset_repeat
|
||||
|
||||
@@ -205,27 +206,52 @@ def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate=
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = model.export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt)
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
path = os.path.join(output_path, f"epoch-{epoch}.safetensors")
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
|
||||
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
|
||||
accelerator = Accelerator()
|
||||
model, dataloader = accelerator.prepare(model, dataloader)
|
||||
os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
|
||||
for data_id, data in enumerate(tqdm(dataloader)):
|
||||
with torch.no_grad():
|
||||
inputs = model.forward_preprocess(data)
|
||||
inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
|
||||
torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))
|
||||
|
||||
|
||||
|
||||
def wan_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the Dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Metadata path of the Dataset.")
|
||||
parser.add_argument("--height", type=int, default=None, help="Image or video height. Leave `height` and `width` None to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Image or video width. Leave `height` and `width` None to enable dynamic resolution.")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in each video. The frames are sampled from the prefix.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in metadata. Separated by commas.")
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times the dataset is repeated in each epoch.")
|
||||
parser.add_argument("--model_paths", type=str, default="", help="Model paths to be loaded. JSON format.")
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Model paths to be loaded. JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin path, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Separated by commas.")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Save path.")
|
||||
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||
parser.add_argument("--task", type=str, default="train_lora", choices=["train_lora", "train_full"], help="Task.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Layers with LoRA modules.")
|
||||
parser.add_argument("--trainable_models", type=str, default=None, help="Trainable models, e.g., dit, vae, text_encoder.")
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Add LoRA on which model.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Add LoRA on which layer.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="LoRA rank.")
|
||||
parser.add_argument("--input_contains_input_image", default=False, action="store_true", help="Model input contains 'input_image'.")
|
||||
parser.add_argument("--input_contains_end_image", default=False, action="store_true", help="Model input contains 'end_image'.")
|
||||
parser.add_argument("--input_contains_control_video", default=False, action="store_true", help="Model input contains 'control_video'.")
|
||||
parser.add_argument("--input_contains_reference_image", default=False, action="store_true", help="Model input contains 'reference_image'.")
|
||||
parser.add_argument("--input_contains_vace_video", default=False, action="store_true", help="Model input contains 'vace_video'.")
|
||||
parser.add_argument("--input_contains_vace_reference_image", default=False, action="store_true", help="Model input contains 'vace_reference_image'.")
|
||||
parser.add_argument("--input_contains_motion_bucket_id", default=False, action="store_true", help="Model input contains 'motion_bucket_id'.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Offload gradient checkpointing to RAM.")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class AutoTorchModule(torch.nn.Module):
|
||||
|
||||
|
||||
class AutoWrappedModule(AutoTorchModule):
|
||||
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
|
||||
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
|
||||
super().__init__()
|
||||
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
||||
self.offload_dtype = offload_dtype
|
||||
@@ -60,7 +60,7 @@ class AutoWrappedModule(AutoTorchModule):
|
||||
|
||||
|
||||
class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
|
||||
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
|
||||
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
|
||||
with init_weights_on_device(device=torch.device("meta")):
|
||||
super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||
self.weight = module.weight
|
||||
@@ -92,7 +92,7 @@ class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
|
||||
|
||||
|
||||
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit):
|
||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, name="", **kwargs):
|
||||
with init_weights_on_device(device=torch.device("meta")):
|
||||
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||
self.weight = module.weight
|
||||
@@ -105,6 +105,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
self.computation_device = computation_device
|
||||
self.vram_limit = vram_limit
|
||||
self.state = 0
|
||||
self.name = name
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if self.state == 2:
|
||||
@@ -121,8 +122,9 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None):
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
|
||||
for name, module in model.named_children():
|
||||
layer_name = name if name_prefix == "" else name_prefix + "." + name
|
||||
for source_module, target_module in module_map.items():
|
||||
if isinstance(module, source_module):
|
||||
num_param = sum(p.numel() for p in module.parameters())
|
||||
@@ -130,12 +132,12 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict,
|
||||
module_config_ = overflow_module_config
|
||||
else:
|
||||
module_config_ = module_config
|
||||
module_ = target_module(module, **module_config_, vram_limit=vram_limit)
|
||||
module_ = target_module(module, **module_config_, vram_limit=vram_limit, name=layer_name)
|
||||
setattr(model, name, module_)
|
||||
total_num_param += num_param
|
||||
break
|
||||
else:
|
||||
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit)
|
||||
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit, name_prefix=layer_name)
|
||||
return total_num_param
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user