mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #1354 from mi804/low_vram_training_ds
low vram training with deepspeed zero3
This commit is contained in:
@@ -1,12 +1,32 @@
|
||||
import torch
|
||||
|
||||
|
||||
try:
|
||||
import deepspeed
|
||||
_HAS_DEEPSPEED = True
|
||||
except ModuleNotFoundError:
|
||||
_HAS_DEEPSPEED = False
|
||||
|
||||
|
||||
def create_custom_forward(module):
    """Wrap *module* in a plain callable for non-reentrant checkpoint APIs.

    Returns a closure that forwards both positional and keyword arguments
    to ``module`` unchanged, so ``torch.utils.checkpoint`` (with
    ``use_reentrant=False``) can invoke it like a free function.
    """
    def _forward(*args, **kwargs):
        return module(*args, **kwargs)

    return _forward
|
||||
|
||||
|
||||
def create_custom_forward_use_reentrant(module):
    """Wrap *module* in a positional-only callable for reentrant checkpointing.

    Reentrant checkpoint implementations (e.g. ``deepspeed.checkpointing``)
    replay the forward with positional tensors only, so the returned closure
    deliberately accepts no keyword arguments.
    """
    def _forward(*args):
        return module(*args)

    return _forward
|
||||
|
||||
|
||||
def judge_args_requires_grad(*args):
    """Return ``True`` if any positional argument is a tensor that requires grad.

    Non-tensor arguments are ignored. Used to decide whether checkpointing
    the forward pass would actually record anything for backward.
    """
    return any(
        isinstance(arg, torch.Tensor) and arg.requires_grad
        for arg in args
    )
|
||||
|
||||
|
||||
def gradient_checkpoint_forward(
|
||||
model,
|
||||
use_gradient_checkpointing,
|
||||
@@ -14,6 +34,17 @@ def gradient_checkpoint_forward(
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if use_gradient_checkpointing and _HAS_DEEPSPEED and deepspeed.checkpointing.is_configured():
|
||||
all_args = args + tuple(kwargs.values())
|
||||
if not judge_args_requires_grad(*all_args):
|
||||
# get the first grad_enabled tensor from un_checkpointed forward
|
||||
model_output = model(*args, **kwargs)
|
||||
else:
|
||||
model_output = deepspeed.checkpointing.checkpoint(
|
||||
create_custom_forward_use_reentrant(model),
|
||||
*all_args,
|
||||
)
|
||||
return model_output
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
|
||||
@@ -29,7 +29,7 @@ def launch_training_task(
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
model.to(device=accelerator.device)
|
||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||
|
||||
initialize_deepspeed_gradient_checkpointing(accelerator)
|
||||
for epoch_id in range(num_epochs):
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
@@ -70,3 +70,19 @@ def launch_data_process_task(
|
||||
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
|
||||
data = model(data)
|
||||
torch.save(data, save_path)
|
||||
|
||||
|
||||
def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):
    """Configure deepspeed activation checkpointing from the accelerator's config.

    If the accelerator was launched without a deepspeed plugin this is a
    silent no-op. If the plugin's config lacks an ``activation_checkpointing``
    section, a notice is printed and nothing is configured; otherwise
    ``deepspeed.checkpointing.configure`` is called with the section's
    ``partition_activations`` / ``cpu_checkpointing`` /
    ``contiguous_memory_optimization`` options (each defaulting to ``False``).
    """
    plugin = getattr(accelerator.state, "deepspeed_plugin", None)
    if plugin is None:
        return
    ds_config = plugin.deepspeed_config
    if "activation_checkpointing" not in ds_config:
        print("Do not find activation_checkpointing config in deepspeed config, skip initializing deepspeed gradient checkpointing.")
        return
    # deepspeed is only required (and imported) when the config asks for it.
    import deepspeed
    act_config = ds_config["activation_checkpointing"]
    deepspeed.checkpointing.configure(
        mpu_=None,
        partition_activations=act_config.get("partition_activations", False),
        checkpoint_in_cpu=act_config.get("cpu_checkpointing", False),
        contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False),
    )
|
||||
|
||||
Reference in New Issue
Block a user