mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge branch 'modelscope:main' into main
This commit is contained in:
@@ -150,6 +150,35 @@ class QwenImagePipeline(BasePipeline):
|
||||
return loss
|
||||
|
||||
|
||||
def _enable_fp8_lora_training(self, dtype):
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
|
||||
from ..models.qwen_image_dit import RMSNorm
|
||||
from ..models.qwen_image_vae import QwenImageRMS_norm
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
|
||||
Qwen2RMSNorm: AutoWrappedModule,
|
||||
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
|
||||
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
|
||||
QwenImageRMS_norm: AutoWrappedModule,
|
||||
}
|
||||
model_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cuda",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cuda",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device="cuda",
|
||||
)
|
||||
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
|
||||
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
|
||||
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
|
||||
self.vram_management_enabled = True
|
||||
if vram_limit is None:
|
||||
|
||||
@@ -338,11 +338,15 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
return trainable_param_names
|
||||
|
||||
|
||||
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
|
||||
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
|
||||
if lora_alpha is None:
|
||||
lora_alpha = lora_rank
|
||||
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
if upcast_dtype is not None:
|
||||
for param in model.parameters():
|
||||
if param.requires_grad:
|
||||
param.data = param.to(upcast_dtype)
|
||||
return model
|
||||
|
||||
|
||||
@@ -555,4 +559,5 @@ def qwen_image_parser():
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
|
||||
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
|
||||
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
|
||||
return parser
|
||||
|
||||
@@ -17,21 +17,27 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
enable_fp8_training=False,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
model_configs += [ModelConfig(path=path) for path in model_paths]
|
||||
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
|
||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
|
||||
|
||||
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
|
||||
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
|
||||
|
||||
# Enable FP8
|
||||
if enable_fp8_training:
|
||||
self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
|
||||
|
||||
# Reset training scheduler (do it in each training step)
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
@@ -43,7 +49,8 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
model = self.add_lora_to_model(
|
||||
getattr(self.pipe, lora_base_model),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank
|
||||
lora_rank=lora_rank,
|
||||
upcast_dtype=self.pipe.torch_dtype,
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
@@ -126,6 +133,7 @@ if __name__ == "__main__":
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
enable_fp8_training=args.enable_fp8_training,
|
||||
)
|
||||
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
|
||||
Reference in New Issue
Block a user