mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
5 Commits
examples-u
...
zero3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fe897883b | ||
|
|
ca9b5e64ea | ||
|
|
2070bbd925 | ||
|
|
3140199c96 | ||
|
|
4e9db263b0 |
@@ -3,14 +3,13 @@ from ..vram.disk_map import DiskMap
|
|||||||
from ..vram.layers import enable_vram_management
|
from ..vram.layers import enable_vram_management
|
||||||
from .file import load_state_dict
|
from .file import load_state_dict
|
||||||
import torch
|
import torch
|
||||||
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
from transformers.utils import ContextManagers
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||||
config = {} if config is None else config
|
config = {} if config is None else config
|
||||||
# Why do we use `skip_model_initialization`?
|
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
|
||||||
# It skips the random initialization of model parameters,
|
|
||||||
# thereby speeding up model loading and avoiding excessive memory usage.
|
|
||||||
with skip_model_initialization():
|
|
||||||
model = model_class(**config)
|
model = model_class(**config)
|
||||||
# What is `module_map`?
|
# What is `module_map`?
|
||||||
# This is a module mapping table for VRAM management.
|
# This is a module mapping table for VRAM management.
|
||||||
@@ -46,7 +45,14 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
|||||||
state_dict = state_dict_converter(state_dict)
|
state_dict = state_dict_converter(state_dict)
|
||||||
else:
|
else:
|
||||||
state_dict = {i: state_dict[i] for i in state_dict}
|
state_dict = {i: state_dict[i] for i in state_dict}
|
||||||
model.load_state_dict(state_dict, assign=True)
|
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
|
||||||
|
# Because at this stage, model parameters are partitioned across multiple GPUs.
|
||||||
|
# Loading them directly could lead to excessive GPU memory consumption.
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||||
|
_load_state_dict_into_zero3_model(model, state_dict)
|
||||||
|
else:
|
||||||
|
model.load_state_dict(state_dict, assign=True)
|
||||||
# Why do we call `to()`?
|
# Why do we call `to()`?
|
||||||
# Because some models override the behavior of `to()`,
|
# Because some models override the behavior of `to()`,
|
||||||
# especially those from libraries like Transformers.
|
# especially those from libraries like Transformers.
|
||||||
@@ -77,3 +83,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
|
|||||||
}
|
}
|
||||||
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
|
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_init_context(torch_dtype, device):
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
from transformers.modeling_utils import set_zero3_state
|
||||||
|
import deepspeed
|
||||||
|
# Why do we use "deepspeed.zero.Init"?
|
||||||
|
# Weight segmentation of the model can be performed on the CPU side
|
||||||
|
# and loading the segmented weights onto the computing card
|
||||||
|
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
|
||||||
|
else:
|
||||||
|
# Why do we use `skip_model_initialization`?
|
||||||
|
# It skips the random initialization of model parameters,
|
||||||
|
# thereby speeding up model loading and avoiding excessive memory usage.
|
||||||
|
init_contexts = [skip_model_initialization()]
|
||||||
|
|
||||||
|
return init_contexts
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ class ModelLogger:
|
|||||||
|
|
||||||
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
|
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
state_dict = accelerator.get_state_dict(model)
|
|
||||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||||
state_dict = self.state_dict_converter(state_dict)
|
state_dict = self.state_dict_converter(state_dict)
|
||||||
os.makedirs(self.output_path, exist_ok=True)
|
os.makedirs(self.output_path, exist_ok=True)
|
||||||
@@ -34,8 +34,8 @@ class ModelLogger:
|
|||||||
|
|
||||||
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
|
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
state_dict = accelerator.get_state_dict(model)
|
|
||||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||||
state_dict = self.state_dict_converter(state_dict)
|
state_dict = self.state_dict_converter(state_dict)
|
||||||
os.makedirs(self.output_path, exist_ok=True)
|
os.makedirs(self.output_path, exist_ok=True)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def launch_training_task(
|
|||||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
|
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||||
|
model.to(device=accelerator.device)
|
||||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||||
|
|
||||||
for epoch_id in range(num_epochs):
|
for epoch_id in range(num_epochs):
|
||||||
@@ -59,6 +59,7 @@ def launch_data_process_task(
|
|||||||
num_workers = args.dataset_num_workers
|
num_workers = args.dataset_num_workers
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
|
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||||
|
model.to(device=accelerator.device)
|
||||||
model, dataloader = accelerator.prepare(model, dataloader)
|
model, dataloader = accelerator.prepare(model, dataloader)
|
||||||
|
|
||||||
for data_id, data in enumerate(tqdm(dataloader)):
|
for data_id, data in enumerate(tqdm(dataloader)):
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import math
|
|||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .wan_video_camera_controller import SimpleAdapter
|
from .wan_video_camera_controller import SimpleAdapter
|
||||||
|
from ..core.gradient import gradient_checkpoint_forward
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import flash_attn_interface
|
import flash_attn_interface
|
||||||
@@ -380,26 +381,14 @@ class WanModel(torch.nn.Module):
|
|||||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
if self.training and use_gradient_checkpointing:
|
if self.training:
|
||||||
if use_gradient_checkpointing_offload:
|
x = gradient_checkpoint_forward(
|
||||||
with torch.autograd.graph.save_on_cpu():
|
block,
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
use_gradient_checkpointing,
|
||||||
create_custom_forward(block),
|
use_gradient_checkpointing_offload,
|
||||||
x, context, t_mod, freqs,
|
x, context, t_mod, freqs
|
||||||
use_reentrant=False,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
x, context, t_mod, freqs,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
x = block(x, context, t_mod, freqs)
|
x = block(x, context, t_mod, freqs)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
|
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
|
||||||
|
from ..core.gradient import gradient_checkpoint_forward
|
||||||
|
|
||||||
|
|
||||||
def torch_dfs(model: nn.Module, parent_name='root'):
|
def torch_dfs(model: nn.Module, parent_name='root'):
|
||||||
@@ -545,46 +546,19 @@ class WanS2VModel(torch.nn.Module):
|
|||||||
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
|
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
for block_id, block in enumerate(self.blocks):
|
for block_id, block in enumerate(self.blocks):
|
||||||
if use_gradient_checkpointing_offload:
|
x = gradient_checkpoint_forward(
|
||||||
with torch.autograd.graph.save_on_cpu():
|
block,
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
use_gradient_checkpointing,
|
||||||
create_custom_forward(block),
|
use_gradient_checkpointing_offload,
|
||||||
x,
|
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
|
||||||
context,
|
)
|
||||||
t_mod,
|
x = gradient_checkpoint_forward(
|
||||||
seq_len_x,
|
lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
|
||||||
pre_compute_freqs[0],
|
use_gradient_checkpointing,
|
||||||
use_reentrant=False,
|
use_gradient_checkpointing_offload,
|
||||||
)
|
x
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
)
|
||||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
|
||||||
x,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
elif use_gradient_checkpointing:
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
x,
|
|
||||||
context,
|
|
||||||
t_mod,
|
|
||||||
seq_len_x,
|
|
||||||
pre_compute_freqs[0],
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
|
||||||
x,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
|
|
||||||
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
|
|
||||||
|
|
||||||
x = x[:, :seq_len_x]
|
x = x[:, :seq_len_x]
|
||||||
x = self.head(x, t[:-1])
|
x = self.head(x, t[:-1])
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from .wan_video_dit import DiTBlock
|
from .wan_video_dit import DiTBlock
|
||||||
|
from ..core.gradient import gradient_checkpoint_forward
|
||||||
|
|
||||||
class VaceWanAttentionBlock(DiTBlock):
|
class VaceWanAttentionBlock(DiTBlock):
|
||||||
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
||||||
@@ -62,26 +62,13 @@ class VaceWanModel(torch.nn.Module):
|
|||||||
dim=1) for u in c
|
dim=1) for u in c
|
||||||
])
|
])
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
for block in self.vace_blocks:
|
for block in self.vace_blocks:
|
||||||
if use_gradient_checkpointing_offload:
|
c = gradient_checkpoint_forward(
|
||||||
with torch.autograd.graph.save_on_cpu():
|
block,
|
||||||
c = torch.utils.checkpoint.checkpoint(
|
use_gradient_checkpointing,
|
||||||
create_custom_forward(block),
|
use_gradient_checkpointing_offload,
|
||||||
c, x, context, t_mod, freqs,
|
c, x, context, t_mod, freqs
|
||||||
use_reentrant=False,
|
)
|
||||||
)
|
|
||||||
elif use_gradient_checkpointing:
|
|
||||||
c = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
c, x, context, t_mod, freqs,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
c = block(c, x, context, t_mod, freqs)
|
|
||||||
hints = torch.unbind(c)[:-1]
|
hints = torch.unbind(c)[:-1]
|
||||||
return hints
|
return hints
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ class Resample(nn.Module):
|
|||||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
return x
|
return x, feat_cache, feat_idx
|
||||||
|
|
||||||
def init_weight(self, conv):
|
def init_weight(self, conv):
|
||||||
conv_weight = conv.weight
|
conv_weight = conv.weight
|
||||||
@@ -298,7 +298,7 @@ class ResidualBlock(nn.Module):
|
|||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x + h
|
return x + h, feat_cache, feat_idx
|
||||||
|
|
||||||
|
|
||||||
class AttentionBlock(nn.Module):
|
class AttentionBlock(nn.Module):
|
||||||
@@ -471,7 +471,7 @@ class Down_ResidualBlock(nn.Module):
|
|||||||
for module in self.downsamples:
|
for module in self.downsamples:
|
||||||
x = module(x, feat_cache, feat_idx)
|
x = module(x, feat_cache, feat_idx)
|
||||||
|
|
||||||
return x + self.avg_shortcut(x_copy)
|
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
|
||||||
|
|
||||||
|
|
||||||
class Up_ResidualBlock(nn.Module):
|
class Up_ResidualBlock(nn.Module):
|
||||||
@@ -511,7 +511,7 @@ class Up_ResidualBlock(nn.Module):
|
|||||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||||
return x_main + x_shortcut
|
return x_main + x_shortcut
|
||||||
else:
|
else:
|
||||||
return x_main
|
return x_main, feat_cache, feat_idx
|
||||||
|
|
||||||
|
|
||||||
class Encoder3d(nn.Module):
|
class Encoder3d(nn.Module):
|
||||||
@@ -586,14 +586,14 @@ class Encoder3d(nn.Module):
|
|||||||
## downsamples
|
## downsamples
|
||||||
for layer in self.downsamples:
|
for layer in self.downsamples:
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## middle
|
## middle
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
@@ -614,7 +614,7 @@ class Encoder3d(nn.Module):
|
|||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
return x, feat_cache, feat_idx
|
||||||
|
|
||||||
|
|
||||||
class Encoder3d_38(nn.Module):
|
class Encoder3d_38(nn.Module):
|
||||||
@@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module):
|
|||||||
## downsamples
|
## downsamples
|
||||||
for layer in self.downsamples:
|
for layer in self.downsamples:
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## middle
|
## middle
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
@@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module):
|
|||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
return x
|
return x, feat_cache, feat_idx
|
||||||
|
|
||||||
|
|
||||||
class Decoder3d(nn.Module):
|
class Decoder3d(nn.Module):
|
||||||
@@ -807,14 +807,14 @@ class Decoder3d(nn.Module):
|
|||||||
## middle
|
## middle
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## upsamples
|
## upsamples
|
||||||
for layer in self.upsamples:
|
for layer in self.upsamples:
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
@@ -835,7 +835,7 @@ class Decoder3d(nn.Module):
|
|||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
return x, feat_cache, feat_idx
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module):
|
|||||||
|
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## upsamples
|
## upsamples
|
||||||
for layer in self.upsamples:
|
for layer in self.upsamples:
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx, first_chunk)
|
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
@@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module):
|
|||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
return x, feat_cache, feat_idx
|
||||||
|
|
||||||
|
|
||||||
def count_conv3d(model):
|
def count_conv3d(model):
|
||||||
@@ -990,11 +990,11 @@ class VideoVAE_(nn.Module):
|
|||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
self._enc_conv_idx = [0]
|
self._enc_conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = self.encoder(x[:, :, :1, :, :],
|
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
|
||||||
feat_cache=self._enc_feat_map,
|
feat_cache=self._enc_feat_map,
|
||||||
feat_idx=self._enc_conv_idx)
|
feat_idx=self._enc_conv_idx)
|
||||||
else:
|
else:
|
||||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||||
feat_cache=self._enc_feat_map,
|
feat_cache=self._enc_feat_map,
|
||||||
feat_idx=self._enc_conv_idx)
|
feat_idx=self._enc_conv_idx)
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
|
|||||||
@@ -348,13 +348,12 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
|
|||||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||||
|
|
||||||
# Forward pass through the model
|
# Forward pass through the model
|
||||||
with torch.inference_mode():
|
output = text_encoder(
|
||||||
output = text_encoder(
|
input_ids=input_ids,
|
||||||
input_ids=input_ids,
|
attention_mask=attention_mask,
|
||||||
attention_mask=attention_mask,
|
output_hidden_states=True,
|
||||||
output_hidden_states=True,
|
use_cache=False,
|
||||||
use_cache=False,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Only use outputs from intermediate layers and stack them
|
# Only use outputs from intermediate layers and stack them
|
||||||
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
|
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
|
||||||
|
|||||||
@@ -1321,11 +1321,6 @@ def model_fn_wan_video(
|
|||||||
if tea_cache_update:
|
if tea_cache_update:
|
||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
def create_custom_forward_vap(block, vap):
|
def create_custom_forward_vap(block, vap):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
return vap(block, *inputs)
|
return vap(block, *inputs)
|
||||||
@@ -1339,32 +1334,24 @@ def model_fn_wan_video(
|
|||||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward_vap(block, vap),
|
create_custom_forward_vap(block, vap),
|
||||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||||
use_reentrant=False,
|
use_reentrant=False
|
||||||
)
|
)
|
||||||
elif use_gradient_checkpointing:
|
elif use_gradient_checkpointing:
|
||||||
x, x_vap = torch.utils.checkpoint.checkpoint(
|
x, x_vap = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward_vap(block, vap),
|
create_custom_forward_vap(block, vap),
|
||||||
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
|
||||||
use_reentrant=False,
|
use_reentrant=False
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
|
x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
|
||||||
else:
|
else:
|
||||||
if use_gradient_checkpointing_offload:
|
x = gradient_checkpoint_forward(
|
||||||
with torch.autograd.graph.save_on_cpu():
|
block,
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
use_gradient_checkpointing,
|
||||||
create_custom_forward(block),
|
use_gradient_checkpointing_offload,
|
||||||
x, context, t_mod, freqs,
|
x, context, t_mod, freqs
|
||||||
use_reentrant=False,
|
)
|
||||||
)
|
|
||||||
elif use_gradient_checkpointing:
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
x, context, t_mod, freqs,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = block(x, context, t_mod, freqs)
|
|
||||||
|
|
||||||
# VACE
|
# VACE
|
||||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||||
@@ -1487,32 +1474,18 @@ def model_fn_wans2v(
|
|||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
for block_id, block in enumerate(dit.blocks):
|
for block_id, block in enumerate(dit.blocks):
|
||||||
if use_gradient_checkpointing_offload:
|
x = gradient_checkpoint_forward(
|
||||||
with torch.autograd.graph.save_on_cpu():
|
block,
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
use_gradient_checkpointing,
|
||||||
create_custom_forward(block),
|
use_gradient_checkpointing_offload,
|
||||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
|
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
|
||||||
x,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
elif use_gradient_checkpointing:
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
)
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
x = gradient_checkpoint_forward(
|
||||||
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
|
lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
|
||||||
x,
|
use_gradient_checkpointing,
|
||||||
use_reentrant=False,
|
use_gradient_checkpointing_offload,
|
||||||
)
|
x
|
||||||
else:
|
)
|
||||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
|
|
||||||
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
|
|
||||||
|
|
||||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from xfuser.core.distributed import (get_sequence_parallel_rank,
|
|||||||
get_sp_group)
|
get_sp_group)
|
||||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||||
from ...core.device import parse_nccl_backend, parse_device_type
|
from ...core.device import parse_nccl_backend, parse_device_type
|
||||||
|
from ...core.gradient import gradient_checkpoint_forward
|
||||||
|
|
||||||
|
|
||||||
def initialize_usp(device_type):
|
def initialize_usp(device_type):
|
||||||
@@ -82,11 +83,6 @@ def usp_dit_forward(self,
|
|||||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
# Context Parallel
|
# Context Parallel
|
||||||
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||||
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||||
@@ -94,20 +90,13 @@ def usp_dit_forward(self,
|
|||||||
x = chunks[get_sequence_parallel_rank()]
|
x = chunks[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
if self.training and use_gradient_checkpointing:
|
if self.training:
|
||||||
if use_gradient_checkpointing_offload:
|
x = gradient_checkpoint_forward(
|
||||||
with torch.autograd.graph.save_on_cpu():
|
block,
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
use_gradient_checkpointing,
|
||||||
create_custom_forward(block),
|
use_gradient_checkpointing_offload,
|
||||||
x, context, t_mod, freqs,
|
x, context, t_mod, freqs
|
||||||
use_reentrant=False,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
x, context, t_mod, freqs,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
x = block(x, context, t_mod, freqs)
|
x = block(x, context, t_mod, freqs)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero3_save_16bit_model: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
enable_cpu_affinity: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero3_save_16bit_model: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
enable_cpu_affinity: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
||||||
|
export CPU_AFFINITY_CONF=1
|
||||||
|
|
||||||
|
accelerate launch examples/flux2/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_image_dataset \
|
||||||
|
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 1 \
|
||||||
|
--model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors,black-forest-labs/FLUX.2-dev:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_
out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8 \
|
||||||
|
--task "sft:data_process"
|
||||||
|
|
||||||
|
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config_zero3.yaml examples/flux2/model_training/train.py \
|
||||||
|
--dataset_base_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:transformer/*.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/FLUX.2-dev-LoRA-splited" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_
out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8 \
|
||||||
|
--initialize_model_on_cpu \
|
||||||
|
--task "sft:train"
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
# This script was tested on 8*910B (NPU).
|
||||||
|
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
||||||
|
export CPU_AFFINITY_CONF=1
|
||||||
|
|
||||||
|
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_image_dataset \
|
||||||
|
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/FLUX.2-klein-9B_full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--use_gradient_checkpointing
|
||||||
|
|
||||||
|
# Edit
|
||||||
|
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||||
|
# --dataset_base_path data/example_image_dataset \
|
||||||
|
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||||
|
# --data_file_keys "image,edit_image" \
|
||||||
|
# --extra_inputs "edit_image" \
|
||||||
|
# --max_pixels 1048576 \
|
||||||
|
# --dataset_repeat 50 \
|
||||||
|
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||||
|
# --learning_rate 1e-5 \
|
||||||
|
# --num_epochs 2 \
|
||||||
|
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
# --output_path "./models/train/FLUX.2-klein-9B_full" \
|
||||||
|
# --trainable_models "dit" \
|
||||||
|
# --use_gradient_checkpointing
|
||||||
@@ -85,6 +85,7 @@ def flux2_parser():
|
|||||||
parser = add_general_config(parser)
|
parser = add_general_config(parser)
|
||||||
parser = add_image_size_config(parser)
|
parser = add_image_size_config(parser)
|
||||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
|
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@@ -126,7 +127,7 @@ if __name__ == "__main__":
|
|||||||
fp8_models=args.fp8_models,
|
fp8_models=args.fp8_models,
|
||||||
offload_models=args.offload_models,
|
offload_models=args.offload_models,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
device=accelerator.device,
|
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
||||||
)
|
)
|
||||||
model_logger = ModelLogger(
|
model_logger = ModelLogger(
|
||||||
args.output_path,
|
args.output_path,
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero3_save_16bit_model: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
enable_cpu_affinity: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
# This script was tested with DeepSpeed ZeRO-3 on 8*910B (NPU).
|
||||||
|
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
||||||
|
export CPU_AFFINITY_CONF=1
|
||||||
|
|
||||||
|
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml examples/qwen_image/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_image_dataset \
|
||||||
|
--dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||||
|
--data_file_keys "image,edit_image" \
|
||||||
|
--extra_inputs "edit_image" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Qwen-Image-Edit-2509_full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--find_unused_parameters
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero3_save_16bit_model: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
enable_cpu_affinity: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero3_save_16bit_model: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
enable_cpu_affinity: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
Reference in New Issue
Block a user