diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index ff9ce50..f1e5e47 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -291,17 +291,21 @@ class WanModel(torch.nn.Module): clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, **kwargs, ): t = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) context = self.text_embedding(context) + if self.has_image_input: x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) + x, (f, h, w) = self.patchify(x) + freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), @@ -315,11 +319,19 @@ class WanModel(torch.nn.Module): for block in self.blocks: if self.training and use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) else: x = block(x, context, t_mod, freqs) diff --git a/diffsynth/models/wan_video_image_encoder.py b/diffsynth/models/wan_video_image_encoder.py index 35f5ea3..b49235b 100644 --- a/diffsynth/models/wan_video_image_encoder.py +++ b/diffsynth/models/wan_video_image_encoder.py @@ -228,7 +228,7 @@ class QuickGELU(nn.Module): class LayerNorm(nn.LayerNorm): def forward(self, x): - return super().forward(x.float()).type_as(x) + return super().forward(x).type_as(x) class SelfAttention(nn.Module): @@ -256,15 +256,11 @@ class SelfAttention(nn.Module): """ x: [B, L, C]. """ - b, s, c, n, d = *x.size(), self.num_heads, self.head_dim - # compute query, key, value - q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + q, k, v = self.to_qkv(x).chunk(3, dim=-1) # compute attention - p = self.attn_dropout if self.training else 0.0 - x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) - x = x.reshape(b, s, c) + x = flash_attention(q, k, v, num_heads=self.num_heads) # output x = self.proj(x) @@ -371,11 +367,11 @@ class AttentionPool(nn.Module): b, s, c, n, d = *x.size(), self.num_heads, self.head_dim # compute query, key, value - q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) - k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1) + k, v = self.to_kv(x).chunk(2, dim=-1) # compute attention - x = flash_attention(q, k, v, version=2) + x = flash_attention(q, k, v, num_heads=self.num_heads) x = x.reshape(b, 1, c) # output @@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module): videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) # forward + dtype = next(iter(self.model.visual.parameters())).dtype + videos = videos.to(dtype) out = self.model.visual(videos, use_31_block=True) return out diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 01b5484..df23076 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module): target_w: target_w + hidden_states_batch.shape[4], ] += mask values = values / weight - values = values.float().clamp_(-1, 1) + values = values.clamp_(-1, 1) return values @@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module): target_w: target_w + hidden_states_batch.shape[4], ] += mask values = values / weight - values = values.float() return values def single_encode(self, video, device): video = video.to(device) x = self.model.encode(video, self.scale) - return x.float() + return x def single_decode(self, hidden_state, device): hidden_state = hidden_state.to(device) video = self.model.decode(hidden_state, self.scale) - return video.float().clamp_(-1, 1) + return video.clamp_(-1, 1) def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 45ef3b3..2f19d42 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -60,7 +60,6 @@ class WanVideoPipeline(BasePipeline): torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: AutoWrappedModule, - torch.nn.LayerNorm: AutoWrappedModule, RMSNorm: AutoWrappedModule, }, module_config = dict( @@ -116,7 +115,7 @@ class WanVideoPipeline(BasePipeline): offload_device="cpu", onload_dtype=dtype, onload_device="cpu", - computation_dtype=self.torch_dtype, + computation_dtype=dtype, computation_device=self.device, ), ) @@ -153,17 +152,21 @@ class WanVideoPipeline(BasePipeline): def encode_image(self, image, num_frames, height, width): - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - image = self.preprocess_image(image.resize((width, height))).to(self.device) - clip_context = self.image_encoder.encode_image([image]) - msk = torch.ones(1, num_frames, height//8, width//8, device=self.device) - msk[:, 1:] = 0 - msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) - msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) - msk = msk.transpose(1, 2)[0] - y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0] - y = torch.concat([msk, y]) - return {"clip_fea": clip_context, "y": [y]} + image = self.preprocess_image(image.resize((width, height))).to(self.device) + clip_context = self.image_encoder.encode_image([image]) + msk = torch.ones(1, num_frames, height//8, width//8, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0] + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device) + y = y.to(dtype=self.torch_dtype, device=self.device) + return {"clip_feature": clip_context, "y": y} def tensor2video(self, frames): @@ -174,18 +177,16 @@ class WanVideoPipeline(BasePipeline): def prepare_extra_input(self, latents=None): - return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4} + return {} def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return latents def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return frames @@ -229,8 +230,8 @@ class WanVideoPipeline(BasePipeline): if input_video is not None: self.load_models_to_device(['vae']) input_video = self.preprocess_images(input_video) - input_video = torch.stack(input_video, dim=2) - latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device) + input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device) + latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) else: latents = noise diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index c695622..f83e85e 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -113,6 +113,7 @@ class LightningModelForDataProcess(pl.LightningModule): self.pipe.device = self.device if video is not None: prompt_emb = self.pipe.encode_prompt(text) + video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] data = {"latents": latents, "prompt_emb": prompt_emb} torch.save(data, path + ".tensors.pth") @@ -145,10 +146,21 @@ class TensorDataset(torch.utils.data.Dataset): class LightningModelForTrain(pl.LightningModule): - def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None): + def __init__( + self, + dit_path, + learning_rate=1e-5, + lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", + use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, + pretrained_lora_path=None + ): super().__init__() model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - model_manager.load_models([dit_path]) + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) self.pipe = WanVideoPipeline.from_model_manager(model_manager) self.pipe.scheduler.set_timesteps(1000, training=True) @@ -167,6 +179,7 @@ class LightningModelForTrain(pl.LightningModule): self.learning_rate = learning_rate self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload def freeze_parameters(self): @@ -210,24 +223,25 @@ class LightningModelForTrain(pl.LightningModule): # Data latents = batch["latents"].to(self.device) prompt_emb = batch["prompt_emb"] - prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) # Loss + self.pipe.device = self.device noise = torch.randn_like(latents) timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) - timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) extra_input = self.pipe.prepare_extra_input(latents) noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) training_target = self.pipe.scheduler.training_target(latents, noise, timestep) # Compute loss - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - noise_pred = self.pipe.denoising_model()( - noisy_latents, timestep=timestep, **prompt_emb, **extra_input, - use_gradient_checkpointing=self.use_gradient_checkpointing - ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) - loss = loss * self.pipe.scheduler.training_weight(timestep) + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, **prompt_emb, **extra_input, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) # Record log self.log("train_loss", loss, prog_bar=True) @@ -410,6 +424,12 @@ def parse_args(): action="store_true", help="Whether to use gradient checkpointing.", ) + parser.add_argument( + "--use_gradient_checkpointing_offload", + default=False, + action="store_true", + help="Whether to use gradient checkpointing offload.", + ) parser.add_argument( "--train_architecture", type=str, @@ -490,6 +510,7 @@ def train(args): lora_target_modules=args.lora_target_modules, init_lora_weights=args.init_lora_weights, use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, pretrained_lora_path=args.pretrained_lora_path, ) if args.use_swanlab: @@ -510,6 +531,7 @@ def train(args): max_epochs=args.max_epochs, accelerator="gpu", devices="auto", + precision="bf16", strategy=args.training_strategy, default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, diff --git a/examples/wanvideo/wan_14b_image_to_video.py b/examples/wanvideo/wan_14b_image_to_video.py index db4d6da..91894ae 100644 --- a/examples/wanvideo/wan_14b_image_to_video.py +++ b/examples/wanvideo/wan_14b_image_to_video.py @@ -11,7 +11,7 @@ snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1- model_manager = ModelManager(device="cpu") model_manager.load_models( ["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"], - torch_dtype=torch.float16, # Image Encoder is loaded with float16 + torch_dtype=torch.float32, # Image Encoder is loaded with float32 ) model_manager.load_models( [