mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
[feature]:Add adaptation of all models to zero3
This commit is contained in:
@@ -21,6 +21,7 @@ def gradient_checkpoint_forward(
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
determinism_check="none"
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
@@ -28,6 +29,7 @@ def gradient_checkpoint_forward(
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
determinism_check="none"
|
||||
)
|
||||
else:
|
||||
model_output = model(*args, **kwargs)
|
||||
|
||||
@@ -3,21 +3,24 @@ from ..vram.disk_map import DiskMap
|
||||
from ..vram.layers import enable_vram_management
|
||||
from .file import load_state_dict
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import ContextManagers
|
||||
|
||||
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None,
|
||||
use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
||||
config = {} if config is None else config
|
||||
# Why do we use `skip_model_initialization`?
|
||||
# It skips the random initialization of model parameters,
|
||||
# thereby speeding up model loading and avoiding excessive memory usage.
|
||||
with skip_model_initialization():
|
||||
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
|
||||
model = model_class(**config)
|
||||
# What is `module_map`?
|
||||
# This is a module mapping table for VRAM management.
|
||||
if module_map is not None:
|
||||
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
|
||||
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"],
|
||||
vram_config["computation_device"]]
|
||||
device = [d for d in devices if d != "disk"][0]
|
||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"],
|
||||
vram_config["computation_dtype"]]
|
||||
dtype = [d for d in dtypes if d != "disk"][0]
|
||||
if vram_config["offload_device"] != "disk":
|
||||
state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||
@@ -26,10 +29,12 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
else:
|
||||
state_dict = {i: state_dict[i] for i in state_dict}
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
|
||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None,
|
||||
vram_limit=vram_limit)
|
||||
else:
|
||||
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
|
||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map,
|
||||
vram_limit=vram_limit)
|
||||
else:
|
||||
# Why do we use `DiskMap`?
|
||||
# Sometimes a model file contains multiple models,
|
||||
@@ -46,7 +51,11 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
else:
|
||||
state_dict = {i: state_dict[i] for i in state_dict}
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||
_load_state_dict_into_zero3_model(model, state_dict)
|
||||
else:
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
# Why do we call `to()`?
|
||||
# Because some models override the behavior of `to()`,
|
||||
# especially those from libraries like Transformers.
|
||||
@@ -56,7 +65,8 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
return model
|
||||
|
||||
|
||||
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
|
||||
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu",
|
||||
state_dict_converter=None, module_map=None):
|
||||
if isinstance(path, str):
|
||||
path = [path]
|
||||
config = {} if config is None else config
|
||||
@@ -77,3 +87,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
|
||||
}
|
||||
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
|
||||
return model
|
||||
|
||||
|
||||
def get_init_context(torch_dtype, device):
|
||||
if is_deepspeed_zero3_enabled():
|
||||
from transformers.modeling_utils import set_zero3_state
|
||||
import deepspeed
|
||||
# Why do we use "deepspeed.zero.Init"?
|
||||
# Weight segmentation of the model can be performed on the CPU side
|
||||
# and loading the segmented weights onto the computing card
|
||||
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
|
||||
else:
|
||||
# Why do we use `skip_model_initialization`?
|
||||
# It skips the random initialization of model parameters,
|
||||
# thereby speeding up model loading and avoiding excessive memory usage.
|
||||
init_contexts = [skip_model_initialization()]
|
||||
|
||||
return init_contexts
|
||||
|
||||
@@ -18,8 +18,8 @@ class ModelLogger:
|
||||
|
||||
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
|
||||
accelerator.wait_for_everyone()
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
@@ -34,8 +34,8 @@ class ModelLogger:
|
||||
|
||||
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
|
||||
accelerator.wait_for_everyone()
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
|
||||
@@ -27,7 +27,7 @@ def launch_training_task(
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
|
||||
model.to(device=accelerator.device)
|
||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||
|
||||
for epoch_id in range(num_epochs):
|
||||
@@ -59,6 +59,7 @@ def launch_data_process_task(
|
||||
num_workers = args.dataset_num_workers
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
model.to(device=accelerator.device)
|
||||
model, dataloader = accelerator.prepare(model, dataloader)
|
||||
|
||||
for data_id, data in enumerate(tqdm(dataloader)):
|
||||
|
||||
@@ -171,7 +171,7 @@ class Resample(nn.Module):
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
@@ -298,7 +298,7 @@ class ResidualBlock(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
return x + h, feat_cache, feat_idx
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
@@ -471,7 +471,7 @@ class Down_ResidualBlock(nn.Module):
|
||||
for module in self.downsamples:
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
|
||||
|
||||
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
@@ -511,7 +511,7 @@ class Up_ResidualBlock(nn.Module):
|
||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||
return x_main + x_shortcut
|
||||
else:
|
||||
return x_main
|
||||
return x_main, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
@@ -586,14 +586,14 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -614,7 +614,7 @@ class Encoder3d(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Encoder3d_38(nn.Module):
|
||||
@@ -698,14 +698,14 @@ class Encoder3d_38(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -730,7 +730,7 @@ class Encoder3d_38(nn.Module):
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
@@ -807,14 +807,14 @@ class Decoder3d(nn.Module):
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -835,7 +835,7 @@ class Decoder3d(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
|
||||
@@ -906,14 +906,14 @@ class Decoder3d_38(nn.Module):
|
||||
|
||||
for layer in self.middle:
|
||||
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -937,7 +937,7 @@ class Decoder3d_38(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x, feat_cache, feat_idx
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
@@ -990,11 +990,11 @@ class VideoVAE_(nn.Module):
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
@@ -348,7 +348,7 @@ class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
with torch.inference_mode():
|
||||
with torch.no_grad():
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
||||
Reference in New Issue
Block a user