Compare commits

...

3 Commits

Author SHA1 Message Date
Artiprocher
040e7fda49 update wan training 2025-06-16 15:46:20 +08:00
Artiprocher
c0706e3fbd support torch<2.6.0 2025-06-16 13:05:54 +08:00
Artiprocher
3d2b51554a update torch version 2025-06-16 12:35:21 +08:00
4 changed files with 109 additions and 47 deletions

View File

@@ -143,10 +143,8 @@ class BasePipeline(torch.nn.Module):
self.vram_management_enabled = True self.vram_management_enabled = True
def get_free_vram(self): def get_vram(self):
total_memory = torch.cuda.get_device_properties(self.device).total_memory return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
allocated_memory = torch.cuda.device_memory_used(self.device)
return (total_memory - allocated_memory) / (1024 ** 3)
def freeze_except(self, model_names): def freeze_except(self, model_names):
@@ -247,7 +245,7 @@ class WanVideoPipeline(BasePipeline):
vram_limit = None vram_limit = None
else: else:
if vram_limit is None: if vram_limit is None:
vram_limit = self.get_free_vram() vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None: if self.text_encoder is not None:
dtype = next(iter(self.text_encoder.parameters())).dtype dtype = next(iter(self.text_encoder.parameters())).dtype

View File

@@ -11,8 +11,9 @@ class VideoDataset(torch.utils.data.Dataset):
def __init__( def __init__(
self, self,
base_path=None, metadata_path=None, base_path=None, metadata_path=None,
frame_interval=1, num_frames=81, num_frames=81,
dynamic_resolution=True, max_pixels=1920*1080, height=None, width=None, time_division_factor=4, time_division_remainder=1,
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16, height_division_factor=16, width_division_factor=16,
data_file_keys=("video",), data_file_keys=("video",),
image_file_extension=("jpg", "jpeg", "png", "webp"), image_file_extension=("jpg", "jpeg", "png", "webp"),
@@ -25,17 +26,15 @@ class VideoDataset(torch.utils.data.Dataset):
metadata_path = args.dataset_metadata_path metadata_path = args.dataset_metadata_path
height = args.height height = args.height
width = args.width width = args.width
max_pixels = args.max_pixels
num_frames = args.num_frames num_frames = args.num_frames
data_file_keys = args.data_file_keys.split(",") data_file_keys = args.data_file_keys.split(",")
repeat = args.dataset_repeat repeat = args.dataset_repeat
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
self.base_path = base_path self.base_path = base_path
self.frame_interval = frame_interval
self.num_frames = num_frames self.num_frames = num_frames
self.dynamic_resolution = dynamic_resolution self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
self.max_pixels = max_pixels self.max_pixels = max_pixels
self.height = height self.height = height
self.width = width self.width = width
@@ -46,9 +45,43 @@ class VideoDataset(torch.utils.data.Dataset):
self.video_file_extension = video_file_extension self.video_file_extension = video_file_extension
self.repeat = repeat self.repeat = repeat
if height is not None and width is not None and dynamic_resolution == True: if height is not None and width is not None:
print("Height and width are fixed. Setting `dynamic_resolution` to False.") print("Height and width are fixed. Setting `dynamic_resolution` to False.")
self.dynamic_resolution = False self.dynamic_resolution = False
elif height is None and width is None:
print("Height and width are none. Setting `dynamic_resolution` to True.")
self.dynamic_resolution = True
if metadata_path is None:
print("No metadata. Trying to generate it.")
metadata = self.generate_metadata(base_path)
print(f"{len(metadata)} lines in metadata.")
else:
metadata = pd.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
def generate_metadata(self, folder):
video_list, prompt_list = [], []
file_set = set(os.listdir(folder))
for file_name in file_set:
if "." not in file_name:
continue
file_ext_name = file_name.split(".")[-1].lower()
file_base_name = file_name[:-len(file_ext_name)-1]
if file_ext_name not in self.image_file_extension and file_ext_name not in self.video_file_extension:
continue
prompt_file_name = file_base_name + ".txt"
if prompt_file_name not in file_set:
continue
with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
prompt = f.read().strip()
video_list.append(file_name)
prompt_list.append(prompt)
metadata = pd.DataFrame()
metadata["video"] = video_list
metadata["prompt"] = prompt_list
return metadata
def crop_and_resize(self, image, target_height, target_width): def crop_and_resize(self, image, target_height, target_width):
@@ -75,15 +108,22 @@ class VideoDataset(torch.utils.data.Dataset):
height, width = self.height, self.width height, width = self.height, self.width
return height, width return height, width
def get_num_frames(self, reader):
num_frames = self.num_frames
if int(reader.count_frames()) < num_frames:
num_frames = int(reader.count_frames())
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def load_frames_using_imageio(self, file_path, start_frame_id, interval, num_frames): def load_video(self, file_path):
reader = imageio.get_reader(file_path) reader = imageio.get_reader(file_path)
if reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: num_frames = self.get_num_frames(reader)
reader.close()
return None
frames = [] frames = []
for frame_id in range(num_frames): for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval) frame = reader.get_data(frame_id)
frame = Image.fromarray(frame) frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame, *self.get_height_width(frame)) frame = self.crop_and_resize(frame, *self.get_height_width(frame))
frames.append(frame) frames.append(frame)
@@ -95,11 +135,6 @@ class VideoDataset(torch.utils.data.Dataset):
image = Image.open(file_path).convert("RGB") image = Image.open(file_path).convert("RGB")
image = self.crop_and_resize(image, *self.get_height_width(image)) image = self.crop_and_resize(image, *self.get_height_width(image))
return image return image
def load_video(self, file_path):
frames = self.load_frames_using_imageio(file_path, 0, self.frame_interval, self.num_frames)
return frames
def is_image(self, file_path): def is_image(self, file_path):
@@ -182,34 +217,50 @@ class DiffusionTrainingModule(torch.nn.Module):
def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate=1e-4, num_epochs=1, output_path="./models", remove_prefix_in_ckpt=None, args=None): class ModelLogger:
if args is not None: def __init__(self, output_path, remove_prefix_in_ckpt=None):
learning_rate = args.learning_rate self.output_path = output_path
num_epochs = args.num_epochs self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
output_path = args.output_path
remove_prefix_in_ckpt = args.remove_prefix_in_ckpt
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
accelerator = Accelerator(gradient_accumulation_steps=1) def on_step_end(self, loss):
pass
def on_epoch_end(self, accelerator, model, epoch_id):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
def launch_training_task(
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LRScheduler,
num_epochs: int = 1,
gradient_accumulation_steps: int = 1,
):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch in range(num_epochs): for epoch_id in range(num_epochs):
for data in tqdm(dataloader): for data in tqdm(dataloader):
with accelerator.accumulate(model): with accelerator.accumulate(model):
optimizer.zero_grad() optimizer.zero_grad()
loss = model(data) loss = model(data)
accelerator.backward(loss) accelerator.backward(loss)
optimizer.step() optimizer.step()
scheduler.step() model_logger.on_step_end(loss)
accelerator.wait_for_everyone() scheduler.step()
if accelerator.is_main_process: model_logger.on_epoch_end(accelerator, model, epoch_id)
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt)
os.makedirs(output_path, exist_ok=True)
path = os.path.join(output_path, f"epoch-{epoch}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
@@ -228,8 +279,9 @@ def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_pat
def wan_parser(): def wan_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.") parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Path to the metadata file of the dataset.") parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution..")
parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.") parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.") parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.") parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
@@ -247,5 +299,6 @@ def wan_parser():
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser return parser

View File

@@ -13,7 +13,8 @@ class AutoTorchModule(torch.nn.Module):
super().__init__() super().__init__()
def check_free_vram(self): def check_free_vram(self):
used_memory = torch.cuda.device_memory_used(self.computation_device) / (1024 ** 3) gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
return used_memory < self.vram_limit return used_memory < self.vram_limit
def offload(self): def offload(self):

View File

@@ -1,6 +1,6 @@
import torch, os, json import torch, os, json
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -107,4 +107,14 @@ if __name__ == "__main__":
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs, extra_inputs=args.extra_inputs,
) )
launch_training_task(model, dataset, args=args) model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
launch_training_task(
dataset, model, model_logger, optimizer, scheduler,
num_epochs=args.num_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps,
)