mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
Merge branch 'modelscope:main' into main
This commit is contained in:
@@ -157,6 +157,8 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
|
|||||||
|
|
||||||
To train the 14B model, pass all of its safetensors shard files as a single comma-separated list. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
|
To train the 14B model, pass all of its safetensors shard files as a single comma-separated list. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
|
||||||
|
|
||||||
|
To train the image-to-video model, add the extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`. In this mode every training sample must be a video; still images in the dataset are rejected with an error.
|
||||||
|
|
||||||
For LoRA training, the Wan-1.3B-T2V model requires 16 GB of VRAM to process 81 frames at 480P, while the Wan-14B-T2V model requires 60 GB of VRAM for the same configuration. To reduce VRAM usage by a further 20%-30%, add the parameter `--use_gradient_checkpointing_offload`.
|
For LoRA training, the Wan-1.3B-T2V model requires 16 GB of VRAM to process 81 frames at 480P, while the Wan-14B-T2V model requires 60 GB of VRAM for the same configuration. To reduce VRAM usage by a further 20%-30%, add the parameter `--use_gradient_checkpointing_offload`.
|
||||||
|
|
||||||
Step 5: Test
|
Step 5: Test
|
||||||
|
|||||||
@@ -7,11 +7,12 @@ from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
|
|||||||
from peft import LoraConfig, inject_adapter_in_model
|
from peft import LoraConfig, inject_adapter_in_model
|
||||||
import torchvision
|
import torchvision
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TextVideoDataset(torch.utils.data.Dataset):
|
class TextVideoDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
|
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
|
||||||
metadata = pd.read_csv(metadata_path)
|
metadata = pd.read_csv(metadata_path)
|
||||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
self.text = metadata["text"].to_list()
|
self.text = metadata["text"].to_list()
|
||||||
@@ -21,6 +22,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
self.height = height
|
self.height = height
|
||||||
self.width = width
|
self.width = width
|
||||||
|
self.is_i2v = is_i2v
|
||||||
|
|
||||||
self.frame_process = v2.Compose([
|
self.frame_process = v2.Compose([
|
||||||
v2.CenterCrop(size=(height, width)),
|
v2.CenterCrop(size=(height, width)),
|
||||||
@@ -48,10 +50,13 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
frames = []
|
frames = []
|
||||||
|
first_frame = None
|
||||||
for frame_id in range(num_frames):
|
for frame_id in range(num_frames):
|
||||||
frame = reader.get_data(start_frame_id + frame_id * interval)
|
frame = reader.get_data(start_frame_id + frame_id * interval)
|
||||||
frame = Image.fromarray(frame)
|
frame = Image.fromarray(frame)
|
||||||
frame = self.crop_and_resize(frame)
|
frame = self.crop_and_resize(frame)
|
||||||
|
if first_frame is None:
|
||||||
|
first_frame = np.array(frame)
|
||||||
frame = frame_process(frame)
|
frame = frame_process(frame)
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
reader.close()
|
reader.close()
|
||||||
@@ -59,7 +64,10 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
frames = torch.stack(frames, dim=0)
|
frames = torch.stack(frames, dim=0)
|
||||||
frames = rearrange(frames, "T C H W -> C T H W")
|
frames = rearrange(frames, "T C H W -> C T H W")
|
||||||
|
|
||||||
return frames
|
if self.is_i2v:
|
||||||
|
return frames, first_frame
|
||||||
|
else:
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
def load_video(self, file_path):
|
def load_video(self, file_path):
|
||||||
@@ -78,6 +86,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
def load_image(self, file_path):
|
def load_image(self, file_path):
|
||||||
frame = Image.open(file_path).convert("RGB")
|
frame = Image.open(file_path).convert("RGB")
|
||||||
frame = self.crop_and_resize(frame)
|
frame = self.crop_and_resize(frame)
|
||||||
|
first_frame = frame
|
||||||
frame = self.frame_process(frame)
|
frame = self.frame_process(frame)
|
||||||
frame = rearrange(frame, "C H W -> C 1 H W")
|
frame = rearrange(frame, "C H W -> C 1 H W")
|
||||||
return frame
|
return frame
|
||||||
@@ -87,10 +96,16 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
text = self.text[data_id]
|
text = self.text[data_id]
|
||||||
path = self.path[data_id]
|
path = self.path[data_id]
|
||||||
if self.is_image(path):
|
if self.is_image(path):
|
||||||
|
if self.is_i2v:
|
||||||
|
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
||||||
video = self.load_image(path)
|
video = self.load_image(path)
|
||||||
else:
|
else:
|
||||||
video = self.load_video(path)
|
video = self.load_video(path)
|
||||||
data = {"text": text, "video": video, "path": path}
|
if self.is_i2v:
|
||||||
|
video, first_frame = video
|
||||||
|
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
||||||
|
else:
|
||||||
|
data = {"text": text, "video": video, "path": path}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -100,22 +115,35 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class LightningModelForDataProcess(pl.LightningModule):
|
class LightningModelForDataProcess(pl.LightningModule):
|
||||||
def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
model_path = [text_encoder_path, vae_path]
|
||||||
|
if image_encoder_path is not None:
|
||||||
|
model_path.append(image_encoder_path)
|
||||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||||
model_manager.load_models([text_encoder_path, vae_path])
|
model_manager.load_models(model_path)
|
||||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
text, video, path = batch["text"][0], batch["video"], batch["path"][0]
|
text, video, path = batch["text"][0], batch["video"], batch["path"][0]
|
||||||
|
|
||||||
self.pipe.device = self.device
|
self.pipe.device = self.device
|
||||||
if video is not None:
|
if video is not None:
|
||||||
|
# prompt
|
||||||
prompt_emb = self.pipe.encode_prompt(text)
|
prompt_emb = self.pipe.encode_prompt(text)
|
||||||
|
# video
|
||||||
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||||
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
|
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
|
||||||
data = {"latents": latents, "prompt_emb": prompt_emb}
|
# image
|
||||||
|
if "first_frame" in batch:
|
||||||
|
first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
|
||||||
|
_, _, num_frames, height, width = video.shape
|
||||||
|
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
|
||||||
|
else:
|
||||||
|
image_emb = {}
|
||||||
|
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
|
||||||
torch.save(data, path + ".tensors.pth")
|
torch.save(data, path + ".tensors.pth")
|
||||||
|
|
||||||
|
|
||||||
@@ -224,6 +252,11 @@ class LightningModelForTrain(pl.LightningModule):
|
|||||||
latents = batch["latents"].to(self.device)
|
latents = batch["latents"].to(self.device)
|
||||||
prompt_emb = batch["prompt_emb"]
|
prompt_emb = batch["prompt_emb"]
|
||||||
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
|
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
|
||||||
|
image_emb = batch["image_emb"]
|
||||||
|
if "clip_feature" in image_emb:
|
||||||
|
image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
|
||||||
|
if "y" in image_emb:
|
||||||
|
image_emb["y"] = image_emb["y"][0].to(self.device)
|
||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
self.pipe.device = self.device
|
self.pipe.device = self.device
|
||||||
@@ -236,7 +269,7 @@ class LightningModelForTrain(pl.LightningModule):
|
|||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
noise_pred = self.pipe.denoising_model()(
|
noise_pred = self.pipe.denoising_model()(
|
||||||
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb,
|
||||||
use_gradient_checkpointing=self.use_gradient_checkpointing,
|
use_gradient_checkpointing=self.use_gradient_checkpointing,
|
||||||
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
|
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
|
||||||
)
|
)
|
||||||
@@ -296,6 +329,12 @@ def parse_args():
|
|||||||
default=None,
|
default=None,
|
||||||
help="Path of text encoder.",
|
help="Path of text encoder.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image_encoder_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path of image encoder.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vae_path",
|
"--vae_path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -466,7 +505,8 @@ def data_process(args):
|
|||||||
frame_interval=1,
|
frame_interval=1,
|
||||||
num_frames=args.num_frames,
|
num_frames=args.num_frames,
|
||||||
height=args.height,
|
height=args.height,
|
||||||
width=args.width
|
width=args.width,
|
||||||
|
is_i2v=args.image_encoder_path is not None
|
||||||
)
|
)
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
@@ -476,6 +516,7 @@ def data_process(args):
|
|||||||
)
|
)
|
||||||
model = LightningModelForDataProcess(
|
model = LightningModelForDataProcess(
|
||||||
text_encoder_path=args.text_encoder_path,
|
text_encoder_path=args.text_encoder_path,
|
||||||
|
image_encoder_path=args.image_encoder_path,
|
||||||
vae_path=args.vae_path,
|
vae_path=args.vae_path,
|
||||||
tiled=args.tiled,
|
tiled=args.tiled,
|
||||||
tile_size=(args.tile_size_height, args.tile_size_width),
|
tile_size=(args.tile_size_height, args.tile_size_width),
|
||||||
|
|||||||
Reference in New Issue
Block a user