mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
File diff suppressed because it is too large
Load Diff
@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
|
|||||||
class LayerNorm(nn.LayerNorm):
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return super().forward(x.float()).type_as(x)
|
return super().forward(x).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
class SelfAttention(nn.Module):
|
class SelfAttention(nn.Module):
|
||||||
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
x: [B, L, C].
|
x: [B, L, C].
|
||||||
"""
|
"""
|
||||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
|
||||||
|
|
||||||
# compute query, key, value
|
# compute query, key, value
|
||||||
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
p = self.attn_dropout if self.training else 0.0
|
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||||
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
|
|
||||||
x = x.reshape(b, s, c)
|
|
||||||
|
|
||||||
# output
|
# output
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
|
|||||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||||
|
|
||||||
# compute query, key, value
|
# compute query, key, value
|
||||||
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
|
||||||
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
k, v = self.to_kv(x).chunk(2, dim=-1)
|
||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
x = flash_attention(q, k, v, version=2)
|
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||||
x = x.reshape(b, 1, c)
|
x = x.reshape(b, 1, c)
|
||||||
|
|
||||||
# output
|
# output
|
||||||
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
|
|||||||
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
||||||
|
|
||||||
# forward
|
# forward
|
||||||
|
dtype = next(iter(self.model.visual.parameters())).dtype
|
||||||
|
videos = videos.to(dtype)
|
||||||
out = self.model.visual(videos, use_31_block=True)
|
out = self.model.visual(videos, use_31_block=True)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
@@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module):
|
|||||||
target_w: target_w + hidden_states_batch.shape[4],
|
target_w: target_w + hidden_states_batch.shape[4],
|
||||||
] += mask
|
] += mask
|
||||||
values = values / weight
|
values = values / weight
|
||||||
values = values.float().clamp_(-1, 1)
|
values = values.clamp_(-1, 1)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
@@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module):
|
|||||||
target_w: target_w + hidden_states_batch.shape[4],
|
target_w: target_w + hidden_states_batch.shape[4],
|
||||||
] += mask
|
] += mask
|
||||||
values = values / weight
|
values = values / weight
|
||||||
values = values.float()
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
def single_encode(self, video, device):
|
def single_encode(self, video, device):
|
||||||
video = video.to(device)
|
video = video.to(device)
|
||||||
x = self.model.encode(video, self.scale)
|
x = self.model.encode(video, self.scale)
|
||||||
return x.float()
|
return x
|
||||||
|
|
||||||
|
|
||||||
def single_decode(self, hidden_state, device):
|
def single_decode(self, hidden_state, device):
|
||||||
hidden_state = hidden_state.to(device)
|
hidden_state = hidden_state.to(device)
|
||||||
video = self.model.decode(hidden_state, self.scale)
|
video = self.model.decode(hidden_state, self.scale)
|
||||||
return video.float().clamp_(-1, 1)
|
return video.clamp_(-1, 1)
|
||||||
|
|
||||||
|
|
||||||
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||||
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
|
from ..models.wan_video_dit import RMSNorm
|
||||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||||
|
|
||||||
|
|
||||||
@@ -60,8 +60,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
torch.nn.Linear: AutoWrappedLinear,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
torch.nn.Conv3d: AutoWrappedModule,
|
torch.nn.Conv3d: AutoWrappedModule,
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
WanLayerNorm: AutoWrappedModule,
|
RMSNorm: AutoWrappedModule,
|
||||||
WanRMSNorm: AutoWrappedModule,
|
|
||||||
},
|
},
|
||||||
module_config = dict(
|
module_config = dict(
|
||||||
offload_dtype=dtype,
|
offload_dtype=dtype,
|
||||||
@@ -116,7 +115,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
offload_device="cpu",
|
offload_device="cpu",
|
||||||
onload_dtype=dtype,
|
onload_dtype=dtype,
|
||||||
onload_device="cpu",
|
onload_device="cpu",
|
||||||
computation_dtype=self.torch_dtype,
|
computation_dtype=dtype,
|
||||||
computation_device=self.device,
|
computation_device=self.device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -153,17 +152,21 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
|
|
||||||
def encode_image(self, image, num_frames, height, width):
|
def encode_image(self, image, num_frames, height, width):
|
||||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
clip_context = self.image_encoder.encode_image([image])
|
||||||
clip_context = self.image_encoder.encode_image([image])
|
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
msk[:, 1:] = 0
|
||||||
msk[:, 1:] = 0
|
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
msk = msk.transpose(1, 2)[0]
|
||||||
msk = msk.transpose(1, 2)[0]
|
|
||||||
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
|
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||||
y = torch.concat([msk, y])
|
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
||||||
return {"clip_fea": clip_context, "y": [y]}
|
y = torch.concat([msk, y])
|
||||||
|
y = y.unsqueeze(0)
|
||||||
|
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
return {"clip_feature": clip_context, "y": y}
|
||||||
|
|
||||||
|
|
||||||
def tensor2video(self, frames):
|
def tensor2video(self, frames):
|
||||||
@@ -174,18 +177,16 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None):
|
def prepare_extra_input(self, latents=None):
|
||||||
return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
@@ -224,12 +225,13 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
||||||
|
|
||||||
# Initialize noise
|
# Initialize noise
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
|
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
|
||||||
|
noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||||
if input_video is not None:
|
if input_video is not None:
|
||||||
self.load_models_to_device(['vae'])
|
self.load_models_to_device(['vae'])
|
||||||
input_video = self.preprocess_images(input_video)
|
input_video = self.preprocess_images(input_video)
|
||||||
input_video = torch.stack(input_video, dim=2)
|
input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
|
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||||
else:
|
else:
|
||||||
latents = noise
|
latents = noise
|
||||||
@@ -252,20 +254,19 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(["dit"])
|
self.load_models_to_device(["dit"])
|
||||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
|
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
|
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
|
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_posi
|
noise_pred = noise_pred_posi
|
||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
self.load_models_to_device(['vae'])
|
self.load_models_to_device(['vae'])
|
||||||
|
|||||||
@@ -104,5 +104,6 @@ class WanPrompter(BasePrompter):
|
|||||||
mask = mask.to(device)
|
mask = mask.to(device)
|
||||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||||
prompt_emb = self.text_encoder(ids, mask)
|
prompt_emb = self.text_encoder(ids, mask)
|
||||||
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
|
for i, v in enumerate(seq_lens):
|
||||||
|
# Zero only sample i's padding (positions >= its own length v). Indexing with `[:, v:]` would apply each sample's length to the whole batch, truncating every prompt to the shortest one.
prompt_emb[i, v:] = 0
|
||||||
return prompt_emb
|
return prompt_emb
|
||||||
|
|||||||
@@ -155,6 +155,10 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
|
|||||||
--use_gradient_checkpointing
|
--use_gradient_checkpointing
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To train the 14B model, whose weights are split across several safetensors files, pass all of the files as a single comma-separated list with no spaces. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
|
||||||
|
|
||||||
|
For LoRA training, the Wan-1.3B-T2V model requires 16 GB of VRAM to process 81 frames at 480P, while the Wan-14B-T2V model requires 60 GB of VRAM for the same configuration. To reduce VRAM usage by a further 20%-30%, add the `--use_gradient_checkpointing_offload` flag.
|
||||||
|
|
||||||
Step 5: Test
|
Step 5: Test
|
||||||
|
|
||||||
Test LoRA:
|
Test LoRA:
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ class LightningModelForDataProcess(pl.LightningModule):
|
|||||||
self.pipe.device = self.device
|
self.pipe.device = self.device
|
||||||
if video is not None:
|
if video is not None:
|
||||||
prompt_emb = self.pipe.encode_prompt(text)
|
prompt_emb = self.pipe.encode_prompt(text)
|
||||||
|
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||||
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
|
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
|
||||||
data = {"latents": latents, "prompt_emb": prompt_emb}
|
data = {"latents": latents, "prompt_emb": prompt_emb}
|
||||||
torch.save(data, path + ".tensors.pth")
|
torch.save(data, path + ".tensors.pth")
|
||||||
@@ -145,10 +146,21 @@ class TensorDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class LightningModelForTrain(pl.LightningModule):
|
class LightningModelForTrain(pl.LightningModule):
|
||||||
def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dit_path,
|
||||||
|
learning_rate=1e-5,
|
||||||
|
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
|
||||||
|
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
|
||||||
|
pretrained_lora_path=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||||
model_manager.load_models([dit_path])
|
if os.path.isfile(dit_path):
|
||||||
|
model_manager.load_models([dit_path])
|
||||||
|
else:
|
||||||
|
dit_path = dit_path.split(",")
|
||||||
|
model_manager.load_models([dit_path])
|
||||||
|
|
||||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
@@ -167,6 +179,7 @@ class LightningModelForTrain(pl.LightningModule):
|
|||||||
|
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
|
||||||
|
|
||||||
def freeze_parameters(self):
|
def freeze_parameters(self):
|
||||||
@@ -210,24 +223,25 @@ class LightningModelForTrain(pl.LightningModule):
|
|||||||
# Data
|
# Data
|
||||||
latents = batch["latents"].to(self.device)
|
latents = batch["latents"].to(self.device)
|
||||||
prompt_emb = batch["prompt_emb"]
|
prompt_emb = batch["prompt_emb"]
|
||||||
prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]
|
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
|
||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
|
self.pipe.device = self.device
|
||||||
noise = torch.randn_like(latents)
|
noise = torch.randn_like(latents)
|
||||||
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
|
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
|
||||||
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
|
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||||
extra_input = self.pipe.prepare_extra_input(latents)
|
extra_input = self.pipe.prepare_extra_input(latents)
|
||||||
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
||||||
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
noise_pred = self.pipe.denoising_model()(
|
||||||
noise_pred = self.pipe.denoising_model()(
|
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||||
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
use_gradient_checkpointing=self.use_gradient_checkpointing,
|
||||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
|
||||||
)
|
)
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||||
|
|
||||||
# Record log
|
# Record log
|
||||||
self.log("train_loss", loss, prog_bar=True)
|
self.log("train_loss", loss, prog_bar=True)
|
||||||
@@ -410,6 +424,12 @@ def parse_args():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether to use gradient checkpointing.",
|
help="Whether to use gradient checkpointing.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_gradient_checkpointing_offload",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use gradient checkpointing offload.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_architecture",
|
"--train_architecture",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -490,6 +510,7 @@ def train(args):
|
|||||||
lora_target_modules=args.lora_target_modules,
|
lora_target_modules=args.lora_target_modules,
|
||||||
init_lora_weights=args.init_lora_weights,
|
init_lora_weights=args.init_lora_weights,
|
||||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
pretrained_lora_path=args.pretrained_lora_path,
|
pretrained_lora_path=args.pretrained_lora_path,
|
||||||
)
|
)
|
||||||
if args.use_swanlab:
|
if args.use_swanlab:
|
||||||
@@ -510,6 +531,7 @@ def train(args):
|
|||||||
max_epochs=args.max_epochs,
|
max_epochs=args.max_epochs,
|
||||||
accelerator="gpu",
|
accelerator="gpu",
|
||||||
devices="auto",
|
devices="auto",
|
||||||
|
precision="bf16",
|
||||||
strategy=args.training_strategy,
|
strategy=args.training_strategy,
|
||||||
default_root_dir=args.output_path,
|
default_root_dir=args.output_path,
|
||||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-
|
|||||||
model_manager = ModelManager(device="cpu")
|
model_manager = ModelManager(device="cpu")
|
||||||
model_manager.load_models(
|
model_manager.load_models(
|
||||||
["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
|
["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
|
||||||
torch_dtype=torch.float16, # Image Encoder is loaded with float16
|
torch_dtype=torch.float32, # Image Encoder is loaded with float32
|
||||||
)
|
)
|
||||||
model_manager.load_models(
|
model_manager.load_models(
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user