mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
wan-series
This commit is contained in:
@@ -274,6 +274,7 @@ class BasePipeline(torch.nn.Module):
|
||||
model_config.path,
|
||||
vram_config=vram_config,
|
||||
vram_limit=vram_limit,
|
||||
clear_parameters=model_config.clear_parameters,
|
||||
)
|
||||
return model_pool
|
||||
|
||||
|
||||
@@ -3,7 +3,10 @@ import torch
|
||||
|
||||
|
||||
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
timestep_id = torch.randint(0, pipe.scheduler.num_train_timesteps, (1,))
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * pipe.scheduler.num_train_timesteps)
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * pipe.scheduler.num_train_timesteps)
|
||||
|
||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
noise = torch.randn_like(inputs["input_latents"])
|
||||
|
||||
@@ -6,7 +6,7 @@ def add_dataset_base_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
|
||||
return parser
|
||||
|
||||
def add_image_size_config(parser: argparse.ArgumentParser):
|
||||
@@ -15,11 +15,19 @@ def add_image_size_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
|
||||
return parser
|
||||
|
||||
def add_video_size_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
||||
return parser
|
||||
|
||||
def add_model_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
|
||||
parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in splited training.")
|
||||
return parser
|
||||
|
||||
def add_training_config(parser: argparse.ArgumentParser):
|
||||
|
||||
@@ -82,32 +82,55 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
else:
|
||||
return data
|
||||
|
||||
def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
|
||||
if fp8:
|
||||
return {
|
||||
"offload_dtype": torch.float8_e4m3fn,
|
||||
"offload_device": device,
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": device,
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": device,
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": device,
|
||||
}
|
||||
elif offload:
|
||||
return {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": device,
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": device,
|
||||
"clear_parameters": True,
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, device="cpu"):
|
||||
def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
|
||||
fp8_models = [] if fp8_models is None else fp8_models.split(",")
|
||||
fp8_config = {
|
||||
# To accommodate multi-GPU training,
|
||||
# the model will be temporarily stored in CPU memory.
|
||||
"offload_dtype": torch.float8_e4m3fn,
|
||||
"offload_device": device,
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": device,
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": device,
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": device,
|
||||
}
|
||||
offload_models = [] if offload_models is None else offload_models.split(",")
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
for path in model_paths:
|
||||
vram_config = fp8_config if path in fp8_models else {}
|
||||
vram_config = self.parse_vram_config(
|
||||
fp8=path in fp8_models,
|
||||
offload=path in offload_models,
|
||||
device=device
|
||||
)
|
||||
model_configs.append(ModelConfig(path=path, **vram_config))
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
for model_id_with_origin_path in model_id_with_origin_paths:
|
||||
model_id, origin_file_pattern = model_id_with_origin_path.split(":")
|
||||
vram_config = fp8_config if model_id_with_origin_path in fp8_models else {}
|
||||
vram_config = self.parse_vram_config(
|
||||
fp8=model_id_with_origin_path in fp8_models,
|
||||
offload=model_id_with_origin_path in offload_models,
|
||||
device=device
|
||||
)
|
||||
model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config))
|
||||
return model_configs
|
||||
|
||||
@@ -118,6 +141,7 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||
preset_lora_path=None, preset_lora_model=None,
|
||||
task="sft",
|
||||
):
|
||||
# Scheduler
|
||||
pipe.scheduler.set_timesteps(1000, training=True)
|
||||
@@ -134,7 +158,7 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
# It is delegated to the subclass.
|
||||
|
||||
# Add LoRA to the base models
|
||||
if lora_base_model is not None:
|
||||
if lora_base_model is not None and not task.endswith(":data_process"):
|
||||
if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
|
||||
print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user