Merge pull request #431 from modelscope/wan-update

Wan update
This commit is contained in:
Zhongjie Duan
2025-03-10 18:26:32 +08:00
committed by GitHub
8 changed files with 345 additions and 713 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
return super().forward(x).type_as(x)
class SelfAttention(nn.Module):
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
# output
x = self.proj(x)
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
k, v = self.to_kv(x).chunk(2, dim=-1)
# compute attention
x = flash_attention(q, k, v, version=2)
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
x = x.reshape(b, 1, c)
# output
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
dtype = next(iter(self.model.visual.parameters())).dtype
videos = videos.to(dtype)
out = self.model.visual(videos, use_31_block=True)
return out

View File

@@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module):
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float().clamp_(-1, 1)
values = values.clamp_(-1, 1)
return values
@@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module):
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float()
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x.float()
return x
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.float().clamp_(-1, 1)
return video.clamp_(-1, 1)
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):

View File

@@ -14,7 +14,7 @@ from tqdm import tqdm
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
from ..models.wan_video_dit import RMSNorm
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
@@ -60,8 +60,7 @@ class WanVideoPipeline(BasePipeline):
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
WanLayerNorm: AutoWrappedModule,
WanRMSNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -116,7 +115,7 @@ class WanVideoPipeline(BasePipeline):
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_dtype=dtype,
computation_device=self.device,
),
)
@@ -153,17 +152,21 @@ class WanVideoPipeline(BasePipeline):
def encode_image(self, image, num_frames, height, width):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
y = torch.concat([msk, y])
return {"clip_fea": clip_context, "y": [y]}
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
y = y.to(dtype=self.torch_dtype, device=self.device)
return {"clip_feature": clip_context, "y": y}
def tensor2video(self, frames):
@@ -174,18 +177,16 @@ class WanVideoPipeline(BasePipeline):
def prepare_extra_input(self, latents=None):
return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
@@ -224,12 +225,13 @@ class WanVideoPipeline(BasePipeline):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
# Initialize noise
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
noise = noise.to(dtype=self.torch_dtype, device=self.device)
if input_video is not None:
self.load_models_to_device(['vae'])
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = noise
@@ -252,20 +254,19 @@ class WanVideoPipeline(BasePipeline):
# Denoise
self.load_models_to_device(["dit"])
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Inference
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode
self.load_models_to_device(['vae'])

View File

@@ -104,5 +104,6 @@ class WanPrompter(BasePrompter):
mask = mask.to(device)
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_emb = self.text_encoder(ids, mask)
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
for i, v in enumerate(seq_lens):
prompt_emb[:, v:] = 0
return prompt_emb

View File

@@ -155,6 +155,10 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
--use_gradient_checkpointing
```
If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`.
Step 5: Test
Test LoRA:

View File

@@ -113,6 +113,7 @@ class LightningModelForDataProcess(pl.LightningModule):
self.pipe.device = self.device
if video is not None:
prompt_emb = self.pipe.encode_prompt(text)
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
data = {"latents": latents, "prompt_emb": prompt_emb}
torch.save(data, path + ".tensors.pth")
@@ -145,10 +146,21 @@ class TensorDataset(torch.utils.data.Dataset):
class LightningModelForTrain(pl.LightningModule):
def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
def __init__(
self,
dit_path,
learning_rate=1e-5,
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
pretrained_lora_path=None
):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([dit_path])
if os.path.isfile(dit_path):
model_manager.load_models([dit_path])
else:
dit_path = dit_path.split(",")
model_manager.load_models([dit_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -167,6 +179,7 @@ class LightningModelForTrain(pl.LightningModule):
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
def freeze_parameters(self):
@@ -210,24 +223,25 @@ class LightningModelForTrain(pl.LightningModule):
# Data
latents = batch["latents"].to(self.device)
prompt_emb = batch["prompt_emb"]
prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
# Loss
self.pipe.device = self.device
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
noise_pred = self.pipe.denoising_model()(
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
noise_pred = self.pipe.denoising_model()(
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing,
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
@@ -410,6 +424,12 @@ def parse_args():
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--use_gradient_checkpointing_offload",
default=False,
action="store_true",
help="Whether to use gradient checkpointing offload.",
)
parser.add_argument(
"--train_architecture",
type=str,
@@ -490,6 +510,7 @@ def train(args):
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
pretrained_lora_path=args.pretrained_lora_path,
)
if args.use_swanlab:
@@ -510,6 +531,7 @@ def train(args):
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision="bf16",
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,

View File

@@ -11,7 +11,7 @@ snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-
model_manager = ModelManager(device="cpu")
model_manager.load_models(
["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
torch_dtype=torch.float16, # Image Encoder is loaded with float16
torch_dtype=torch.float32, # Image Encoder is loaded with float32
)
model_manager.load_models(
[