mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Flux fp8 lora training (#221)
* flux fp8 lora training --------- Co-authored-by: tc2000731 <tc2000731@163.com>
This commit is contained in:
@@ -441,13 +441,12 @@ class FluxDiT(torch.nn.Module):
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
class Linear(torch.nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight,bias= cast_bias_weight(self.module,input)
|
||||
weight,bias= cast_bias_weight(self,input)
|
||||
return torch.nn.functional.linear(input,weight,bias)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
@@ -466,7 +465,11 @@ class FluxDiT(torch.nn.Module):
|
||||
def replace_layer(model):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
new_layer = quantized_layer.Linear(module)
|
||||
new_layer = quantized_layer.Linear(module.in_features,module.out_features)
|
||||
new_layer.weight.data = module.weight.data
|
||||
if module.bias is not None:
|
||||
new_layer.bias.data = module.bias.data
|
||||
# del module
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
new_layer = quantized_layer.RMSNorm(module)
|
||||
|
||||
@@ -11,13 +11,22 @@ class LightningModel(LightningModelForT2ILoRA):
|
||||
torch_dtype=torch.float16, pretrained_weights=[],
|
||||
learning_rate=1e-4, use_gradient_checkpointing=True,
|
||||
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out",
|
||||
state_dict_converter=None,
|
||||
state_dict_converter=None, quantize = None
|
||||
):
|
||||
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
|
||||
model_manager.load_models(pretrained_weights)
|
||||
if quantize is None:
|
||||
model_manager.load_models(pretrained_weights)
|
||||
else:
|
||||
model_manager.load_models(pretrained_weights[1:])
|
||||
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
|
||||
|
||||
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
if quantize is not None:
|
||||
self.pipe.dit.quantize()
|
||||
|
||||
self.pipe.scheduler.set_timesteps(1000)
|
||||
|
||||
self.freeze_parameters()
|
||||
@@ -66,6 +75,13 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Whether to export lora files aligned with other opensource format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["float8_e4m3fn"],
|
||||
help="Whether to use quantization when training the model, and in which format.",
|
||||
)
|
||||
parser = add_general_parsers(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
@@ -75,12 +91,13 @@ if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
model = LightningModel(
|
||||
torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
|
||||
pretrained_weights=[args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_dit_path, args.pretrained_vae_path],
|
||||
pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
|
||||
learning_rate=args.learning_rate,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
|
||||
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
|
||||
)
|
||||
launch_training_task(model, args)
|
||||
|
||||
Reference in New Issue
Block a user