diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py index 5d9b052..88c3089 100644 --- a/diffsynth/core/loader/model.py +++ b/diffsynth/core/loader/model.py @@ -8,19 +8,16 @@ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils import ContextManagers -def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, - use_disk_map=False, module_map=None, vram_config=None, vram_limit=None): +def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None): config = {} if config is None else config with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)): model = model_class(**config) # What is `module_map`? # This is a module mapping table for VRAM management. if module_map is not None: - devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], - vram_config["computation_device"]] + devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]] device = [d for d in devices if d != "disk"][0] - dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], - vram_config["computation_dtype"]] + dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]] dtype = [d for d in dtypes if d != "disk"][0] if vram_config["offload_device"] != "disk": state_dict = DiskMap(path, device, torch_dtype=dtype) @@ -29,12 +26,10 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic else: state_dict = {i: state_dict[i] for i in state_dict} model.load_state_dict(state_dict, assign=True) - model = enable_vram_management(model, module_map, 
vram_config=vram_config, disk_map=None, - vram_limit=vram_limit) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit) else: disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter) - model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, - vram_limit=vram_limit) + model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit) else: # Why do we use `DiskMap`? # Sometimes a model file contains multiple models, @@ -51,6 +46,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic state_dict = state_dict_converter(state_dict) else: state_dict = {i: state_dict[i] for i in state_dict} + # Why does DeepSpeed ZeRO Stage 3 need to be handled separately? + # Because under ZeRO Stage 3, model parameters are partitioned across multiple GPUs. + # Loading them directly could lead to excessive GPU memory consumption. 
if is_deepspeed_zero3_enabled(): from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model _load_state_dict_into_zero3_model(model, state_dict) @@ -65,8 +63,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic return model -def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", - state_dict_converter=None, module_map=None): +def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None): if isinstance(path, str): path = [path] config = {} if config is None else config diff --git a/docs/en/Model_Details/Qwen-Image.md b/docs/en/Model_Details/Qwen-Image.md index 08b8a35..7a72036 100644 --- a/docs/en/Model_Details/Qwen-Image.md +++ b/docs/en/Model_Details/Qwen-Image.md @@ -106,6 +106,11 @@ Special Training Scripts: * Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/) * End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) +DeepSpeed ZeRO Stage 3 Training: The Qwen-Image series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Qwen-Image model as an example, the following modifications are required: + +* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml` +* `--initialize_model_on_cpu` + ## Model Inference Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). 
diff --git a/docs/en/Model_Details/Wan.md b/docs/en/Model_Details/Wan.md index 83141bf..d6f6380 100644 --- a/docs/en/Model_Details/Wan.md +++ b/docs/en/Model_Details/Wan.md @@ -142,6 +142,11 @@ graph LR; * Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/) * End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/) +DeepSpeed ZeRO Stage 3 Training: The Wan series models support DeepSpeed ZeRO Stage 3 training, which partitions the model across multiple GPUs. Taking full parameter training of the Wan2.1-T2V-14B model as an example, the following modifications are required: + +* `--config_file examples/wanvideo/model_training/full/accelerate_config_zero3.yaml` +* `--initialize_model_on_cpu` + ## Model Inference Models are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models). 
diff --git a/docs/zh/Model_Details/Qwen-Image.md b/docs/zh/Model_Details/Qwen-Image.md index 697438f..39bda76 100644 --- a/docs/zh/Model_Details/Qwen-Image.md +++ b/docs/zh/Model_Details/Qwen-Image.md @@ -106,6 +106,11 @@ graph LR; * 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/qwen_image/model_training/special/split_training/) * 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) +DeepSpeed ZeRO 3 训练:Qwen-Image 系列模型支持 DeepSpeed ZeRO 3 训练,将模型拆分到多个 GPU 上,以 Qwen-Image 模型的全量训练为例,需修改: + +* `--config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml` +* `--initialize_model_on_cpu` + ## 模型推理 模型通过 `QwenImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。 diff --git a/docs/zh/Model_Details/Wan.md b/docs/zh/Model_Details/Wan.md index b8c3032..379c640 100644 --- a/docs/zh/Model_Details/Wan.md +++ b/docs/zh/Model_Details/Wan.md @@ -143,6 +143,11 @@ graph LR; * 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/wanvideo/model_training/special/split_training/) * 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/wanvideo/model_training/special/direct_distill/) +DeepSpeed ZeRO 3 训练:Wan 系列模型支持 DeepSpeed ZeRO 3 训练,将模型拆分到多个 GPU 上,以 Wan2.1-T2V-14B 模型的全量训练为例,需修改: + +* `--config_file examples/wanvideo/model_training/full/accelerate_config_zero3.yaml` +* `--initialize_model_on_cpu` + ## 模型推理 模型通过 `WanVideoPipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。